diff --git a/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify.rs b/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify.rs index 38873b27043..325d7fafe32 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify.rs @@ -1,8 +1,10 @@ +use std::cmp::Ordering; + use crate::ssa::ir::{ basic_block::BasicBlockId, dfg::simplify::value_merger::ValueMerger, instruction::{ - Binary, BinaryOp, ConstrainError, Instruction, + Binary, BinaryOp, ConstrainError, Instruction, Intrinsic, binary::{truncate, truncate_field}, }, types::{NumericType, Type}, @@ -380,6 +382,8 @@ fn optimize_length_one_array_read( /// - If the array value is constant, we use that array. /// - If the array value is from a previous array-set, we recur. /// - If the array value is from an array parameter, we use that array. +/// - If the array value is the result of a slice intrinsic, we try to use the slice, +/// adjusting the index as necessary. /// /// That is, we have multiple `array_set` instructions setting various constant indexes /// of the same array, returning a modified version. We want to go backwards until we @@ -387,13 +391,8 @@ fn optimize_length_one_array_read( fn try_optimize_array_get_from_previous_set( dfg: &mut DataFlowGraph, mut array_id: ValueId, - target_index: FieldElement, + mut target_index: FieldElement, ) -> SimplifyResult { - // The target index must be less than the maximum array length - let Some(target_index_u32) = target_index.try_to_u32() else { - return SimplifyResult::None; - }; - // Arbitrary number of maximum tries just to prevent this optimization from taking too long. let max_tries = 5; for _ in 0..max_tries { @@ -410,14 +409,139 @@ fn try_optimize_array_get_from_previous_set( } } Instruction::MakeArray { elements: array, typ: _ } => { - let index = target_index_u32 as usize; + let Some(target_index) = target_index.try_to_u32() else { + return SimplifyResult::None; + }; + + let index = target_index as usize; if index < array.len() { return SimplifyResult::SimplifiedTo(array[index]); } } + Instruction::Call { func, arguments } => { + if let Some(as_slice) = dfg.get_intrinsic(Intrinsic::AsSlice) { + if func == as_slice { + array_id = arguments[0]; + continue; + } + } + + if let Some(slice_insert) = dfg.get_intrinsic(Intrinsic::SliceInsert) { + // Only simplify when a single value is pushed + if func == slice_insert && arguments.len() == 4 { + let slice = arguments[1]; + let insert_index = arguments[2]; + let insert_value = arguments[3]; + if let Some(insert_index) = dfg.get_numeric_constant(insert_index) { + match target_index.cmp(&insert_index) { + Ordering::Less => { + array_id = slice; + continue; + } + Ordering::Equal => { + return SimplifyResult::SimplifiedTo(insert_value); + } + Ordering::Greater => { + if !target_index.is_zero() { + array_id = slice; + target_index -= FieldElement::one(); + continue; + } + } + } + } + } + } + + if let Some(slice_remove) = dfg.get_intrinsic(Intrinsic::SliceRemove) { + if func == slice_remove { + let slice = arguments[1]; + let remove_index = arguments[2]; + if let Some(remove_index) = dfg.get_numeric_constant(remove_index) { + // Only optimize for non-composite slices + if dfg.type_of_value(slice).element_size() == 1 { + match target_index.cmp(&remove_index) { + Ordering::Less => { + array_id = slice; + continue; + } + Ordering::Equal | Ordering::Greater => { + if !target_index.is_zero() { + array_id = slice; + target_index += FieldElement::one(); + continue; + } + } + } + } + } + } + } + + if let Some(slice_push_front) = dfg.get_intrinsic(Intrinsic::SlicePushFront) { + // Only simplify when a single value is pushed + if func == slice_push_front && arguments.len() == 3 { + let slice = arguments[1]; + let pushed_value = arguments[2]; + if target_index.is_zero() { + return SimplifyResult::SimplifiedTo(pushed_value); + } else { + array_id = slice; + target_index -= FieldElement::one(); + continue; + } + } + } + + if let Some(slice_pop_front) = dfg.get_intrinsic(Intrinsic::SlicePopFront) { + if func == slice_pop_front { + let slice = arguments[1]; + // Only optimize for non-composite slices + if dfg.type_of_value(slice).element_size() == 1 { + array_id = slice; + target_index += FieldElement::one(); + continue; + } + } + } + + if let Some(slice_push_back) = dfg.get_intrinsic(Intrinsic::SlicePushBack) { + // Only simplify when a single value is pushed + if func == slice_push_back { + let length = arguments[0]; + let slice = arguments[1]; + let pushed_value = arguments[2]; + // Only optimize if the length is known + if let Some(length) = dfg.get_numeric_constant(length) { + if target_index == length { + return SimplifyResult::SimplifiedTo(pushed_value); + } else { + array_id = slice; + continue; + } + } + } + } + + if let Some(slice_pop_back) = dfg.get_intrinsic(Intrinsic::SlicePopBack) { + if func == slice_pop_back { + let slice = arguments[1]; + // Only optimize for non-composite slices + if dfg.type_of_value(slice).element_size() == 1 { + array_id = slice; + target_index += FieldElement::one(); + continue; + } + } + } + } _ => (), } } else if let Value::Param { typ: Type::Array(_, length), .. } = &dfg[array_id] { + let Some(target_index_u32) = target_index.try_to_u32() else { + return SimplifyResult::None; + }; + if target_index_u32 < *length { let index = dfg.make_constant(target_index, NumericType::length_type()); return SimplifyResult::SimplifiedToInstruction(Instruction::ArrayGet { @@ -775,4 +899,208 @@ mod tests { let ssa = Ssa::from_str_simplifying(src).unwrap(); assert_normalized_ssa_equals(ssa, src); } + + #[test] + fn simplifies_array_get_on_slice_insert() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: [Field; 3]): + v2, v3 = call as_slice(v0) -> (u32, [Field]) + v8, v9 = call slice_insert(u32 3, v3, u32 1, Field 10) -> (u32, [Field]) + v10 = array_get v9, index u32 0 -> Field + v11 = array_get v9, index u32 1 -> Field + v12 = array_get v9, index u32 2 -> Field + v13 = array_get v9, index u32 3 -> Field + return v10, v11, v12, v13 + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + + // We can see that the array_gets now read from v0 instead of v9, + // and for the index 1 we use the inserted value. + // The get with index 3 isn't simplified because it's out of bounds. + assert_ssa_snapshot!(ssa, @r" + acir(inline) predicate_pure fn main f0 { + b0(v0: [Field; 3]): + v2, v3 = call as_slice(v0) -> (u32, [Field]) + v8, v9 = call slice_insert(u32 3, v3, u32 1, Field 10) -> (u32, [Field]) + v11 = array_get v0, index u32 0 -> Field + v12 = array_get v0, index u32 1 -> Field + v14 = array_get v0, index u32 2 -> Field + return v11, Field 10, v12, v14 + } + "); + } + + #[test] + fn simplifies_array_get_on_slice_remove() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: [Field; 4]): + v2, v3 = call as_slice(v0) -> (u32, [Field]) + v8, v9 = call slice_remove(u32 4, v3, u32 1) -> (u32, [Field]) + v10 = array_get v9, index u32 0 -> Field + v11 = array_get v9, index u32 1 -> Field + v12 = array_get v9, index u32 2 -> Field + v13 = array_get v9, index u32 3 -> Field + v14 = array_get v9, index u32 4 -> Field + return v10, v11, v12, v13, v14 + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + + // We can see that the array_gets now read from v0 instead of v9. + // Indexes equal or greater than the removal index now get from + // the original array at `index + 1`. + // Indexes that are less that the removal index remain the same. + assert_ssa_snapshot!(ssa, @r" + acir(inline) predicate_pure fn main f0 { + b0(v0: [Field; 4]): + v2, v3 = call as_slice(v0) -> (u32, [Field]) + v7, v8 = call slice_remove(u32 4, v3, u32 1) -> (u32, [Field]) + v10 = array_get v0, index u32 0 -> Field + v12 = array_get v0, index u32 2 -> Field + v14 = array_get v0, index u32 3 -> Field + v15 = array_get v8, index u32 3 -> Field + v16 = array_get v8, index u32 4 -> Field + return v10, v12, v14, v15, v16 + } + "); + } + + #[test] + fn simplifies_array_get_on_slice_push_front() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: [Field; 3]): + v2, v3 = call as_slice(v0) -> (u32, [Field]) + v8, v9 = call slice_push_front(u32 3, v3, Field 10) -> (u32, [Field]) + v10 = array_get v9, index u32 0 -> Field + v11 = array_get v9, index u32 1 -> Field + v12 = array_get v9, index u32 2 -> Field + v13 = array_get v9, index u32 3 -> Field + return v10, v11, v12, v13 + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + + // We can see that the array_gets now read from v0 instead of v9. + // If the index is zero, we use the pushed value. + // Otherwise, the new index is `index - 1` as in the original + // slice elements happen in previous positions. + assert_ssa_snapshot!(ssa, @r" + acir(inline) predicate_pure fn main f0 { + b0(v0: [Field; 3]): + v2, v3 = call as_slice(v0) -> (u32, [Field]) + v7, v8 = call slice_push_front(u32 3, v3, Field 10) -> (u32, [Field]) + v10 = array_get v0, index u32 0 -> Field + v12 = array_get v0, index u32 1 -> Field + v14 = array_get v0, index u32 2 -> Field + return Field 10, v10, v12, v14 + } + "); + } + + #[test] + fn simplifies_array_get_on_slice_pop_front() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: [Field; 4]): + v2, v3 = call as_slice(v0) -> (u32, [Field]) + v7, v8, v9 = call slice_pop_front(u32 4, v3) -> (Field, u32, [Field]) + v10 = array_get v9, index u32 0 -> Field + v11 = array_get v9, index u32 1 -> Field + v12 = array_get v9, index u32 2 -> Field + v13 = array_get v9, index u32 3 -> Field + v14 = array_get v9, index u32 4 -> Field + return v10, v11, v12, v13, v14 + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + + // We can see that the array_gets now read from v0 instead of v9. + // Indexes equal or greater than the removal index now get from + // the original array at `index + 1`. + // Indexes that are less that the removal index remain the same. + assert_ssa_snapshot!(ssa, @r" + acir(inline) predicate_pure fn main f0 { + b0(v0: [Field; 4]): + v2, v3 = call as_slice(v0) -> (u32, [Field]) + v6, v7, v8 = call slice_pop_front(u32 4, v3) -> (Field, u32, [Field]) + v10 = array_get v0, index u32 1 -> Field + v12 = array_get v0, index u32 2 -> Field + v14 = array_get v0, index u32 3 -> Field + v15 = array_get v8, index u32 3 -> Field + v16 = array_get v8, index u32 4 -> Field + return v10, v12, v14, v15, v16 + } + "); + } + + #[test] + fn simplifies_array_get_on_slice_push_back() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: [Field; 3]): + v2, v3 = call as_slice(v0) -> (u32, [Field]) + v8, v9 = call slice_push_back(u32 3, v3, Field 10) -> (u32, [Field]) + v10 = array_get v9, index u32 0 -> Field + v11 = array_get v9, index u32 1 -> Field + v12 = array_get v9, index u32 2 -> Field + v13 = array_get v9, index u32 3 -> Field + v14 = array_get v9, index u32 4 -> Field + return v10, v11, v12, v13, v14 + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + + // We can see that the array_gets now read from v0 instead of v9, + // When the index is the same as the slice length, we use the pushed value. + // Otherwise, the index remains the same. + assert_ssa_snapshot!(ssa, @r" + acir(inline) predicate_pure fn main f0 { + b0(v0: [Field; 3]): + v2, v3 = call as_slice(v0) -> (u32, [Field]) + v7, v8 = call slice_push_back(u32 3, v3, Field 10) -> (u32, [Field]) + v10 = array_get v0, index u32 0 -> Field + v12 = array_get v0, index u32 1 -> Field + v14 = array_get v0, index u32 2 -> Field + v16 = array_get v8, index u32 4 -> Field + return v10, v12, v14, Field 10, v16 + } + "); + } + + #[test] + fn simplifies_array_get_on_slice_pop_back() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: [Field; 4]): + v2, v3 = call as_slice(v0) -> (u32, [Field]) + v7, v8, v9 = call slice_pop_back(u32 4, v3) -> (u32, [Field], Field) + v10 = array_get v8, index u32 0 -> Field + v11 = array_get v8, index u32 1 -> Field + v12 = array_get v8, index u32 2 -> Field + v13 = array_get v8, index u32 3 -> Field + v14 = array_get v8, index u32 4 -> Field + return v10, v11, v12, v13, v14 + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + + // We can see that the array_gets now read from v0 instead of v9. + assert_ssa_snapshot!(ssa, @r" + acir(inline) predicate_pure fn main f0 { + b0(v0: [Field; 4]): + v2, v3 = call as_slice(v0) -> (u32, [Field]) + v6, v7, v8 = call slice_pop_back(u32 4, v3) -> (u32, [Field], Field) + v10 = array_get v0, index u32 1 -> Field + v12 = array_get v0, index u32 2 -> Field + v14 = array_get v0, index u32 3 -> Field + v15 = array_get v7, index u32 3 -> Field + v16 = array_get v7, index u32 4 -> Field + return v10, v12, v14, v15, v16 + } + "); + } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs index e76429b955b..a514d0b2bb0 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs @@ -669,11 +669,10 @@ mod tests { v8, v9 = call slice_push_front(v6, v3, v2) -> (u32, [Field]) v10 = not v0 v11 = cast v0 as u32 - v13 = array_get v9, index u32 0 -> Field - v14 = make_array [v13] : [Field] + v12 = make_array [v2] : [Field] enable_side_effects u1 1 - v17 = add v11, u32 1 - v18 = make_array [v2, v13] : [Field] + v15 = add v11, u32 1 + v16 = make_array [v2, v2] : [Field] constrain v2 == Field 1 return } @@ -721,11 +720,10 @@ mod tests { v11, v12 = call slice_push_front(v9, v6, v2) -> (u32, [Field]) v13 = not v0 v14 = cast v0 as u32 - v16 = array_get v12, index u32 0 -> Field - v17 = make_array [v16] : [Field] + v15 = make_array [v2] : [Field] enable_side_effects u1 1 - v20 = add v14, u32 1 - v21 = make_array [v2, v16] : [Field] + v18 = add v14, u32 1 + v19 = make_array [v2, v2] : [Field] constrain v2 == Field 1 return } @@ -769,11 +767,10 @@ mod tests { v9, v10 = call slice_insert(v6, v3, u32 0, v2) -> (u32, [Field]) v11 = not v0 v12 = cast v0 as u32 - v13 = array_get v10, index u32 0 -> Field - v14 = make_array [v13] : [Field] + v13 = make_array [v2] : [Field] enable_side_effects u1 1 - v17 = add v12, u32 1 - v18 = make_array [v2, v13] : [Field] + v16 = add v12, u32 1 + v17 = make_array [v2, v2] : [Field] constrain v2 == Field 1 return } @@ -818,17 +815,16 @@ mod tests { v10, v11, v12 = call slice_pop_back(v8, v5) -> (u32, [Field], Field) v13 = not v0 v14 = cast v0 as u32 - v16 = array_get v11, index u32 0 -> Field - v17 = cast v0 as Field - v18 = cast v13 as Field - v19 = mul v17, v16 - v20 = mul v18, Field 2 - v21 = add v19, v20 - v22 = make_array [v21, Field 3] : [Field] + v15 = cast v0 as Field + v16 = cast v13 as Field + v17 = mul v15, Field 3 + v18 = mul v16, Field 2 + v19 = add v17, v18 + v20 = make_array [v19, Field 3] : [Field] enable_side_effects u1 1 - v24, v25, v26 = call slice_pop_back(v14, v22) -> (u32, [Field], Field) - v27 = array_get v25, index u32 0 -> Field - constrain v27 == Field 1 + v22, v23, v24 = call slice_pop_back(v14, v20) -> (u32, [Field], Field) + v26 = array_get v23, index u32 0 -> Field + constrain v26 == Field 1 return } "); @@ -872,17 +868,16 @@ mod tests { v10, v11, v12 = call slice_pop_front(v8, v5) -> (Field, u32, [Field]) v13 = not v0 v14 = cast v0 as u32 - v16 = array_get v12, index u32 0 -> Field - v17 = cast v0 as Field - v18 = cast v13 as Field - v19 = mul v17, v16 - v20 = mul v18, Field 2 - v21 = add v19, v20 - v22 = make_array [v21, Field 3] : [Field] + v15 = cast v0 as Field + v16 = cast v13 as Field + v17 = mul v15, Field 3 + v18 = mul v16, Field 2 + v19 = add v17, v18 + v20 = make_array [v19, Field 3] : [Field] enable_side_effects u1 1 - v24, v25, v26 = call slice_pop_front(v14, v22) -> (Field, u32, [Field]) - v27 = array_get v26, index u32 0 -> Field - constrain v27 == Field 1 + v22, v23, v24 = call slice_pop_front(v14, v20) -> (Field, u32, [Field]) + v26 = array_get v24, index u32 0 -> Field + constrain v26 == Field 1 return } ");