Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
344 changes: 336 additions & 8 deletions compiler/noirc_evaluator/src/ssa/ir/dfg/simplify.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::cmp::Ordering;

use crate::ssa::ir::{
basic_block::BasicBlockId,
dfg::simplify::value_merger::ValueMerger,
instruction::{
Binary, BinaryOp, ConstrainError, Instruction,
Binary, BinaryOp, ConstrainError, Instruction, Intrinsic,
binary::{truncate, truncate_field},
},
types::{NumericType, Type},
Expand Down Expand Up @@ -380,20 +382,17 @@ fn optimize_length_one_array_read(
/// - If the array value is constant, we use that array.
/// - If the array value is from a previous array-set, we recur.
/// - If the array value is from an array parameter, we use that array.
/// - If the array value is the result of a slice intrinsic, we try to use the slice,
/// adjusting the index as necessary.
///
/// That is, we have multiple `array_set` instructions setting various constant indexes
/// of the same array, returning a modified version. We want to go backwards until we
/// find the last `array_set` for the index we are interested in, and return the value set.
fn try_optimize_array_get_from_previous_set(
dfg: &mut DataFlowGraph,
mut array_id: ValueId,
target_index: FieldElement,
mut target_index: FieldElement,
) -> SimplifyResult {
// The target index must be less than the maximum array length
let Some(target_index_u32) = target_index.try_to_u32() else {
return SimplifyResult::None;
};

// Arbitrary number of maximum tries just to prevent this optimization from taking too long.
let max_tries = 5;
for _ in 0..max_tries {
Expand All @@ -410,14 +409,139 @@ fn try_optimize_array_get_from_previous_set(
}
}
Instruction::MakeArray { elements: array, typ: _ } => {
let index = target_index_u32 as usize;
let Some(target_index) = target_index.try_to_u32() else {
return SimplifyResult::None;
};

let index = target_index as usize;
if index < array.len() {
return SimplifyResult::SimplifiedTo(array[index]);
}
}
Instruction::Call { func, arguments } => {
if let Some(as_slice) = dfg.get_intrinsic(Intrinsic::AsSlice) {
if func == as_slice {
array_id = arguments[0];
continue;
}
}

if let Some(slice_insert) = dfg.get_intrinsic(Intrinsic::SliceInsert) {
// Only simplify when a single value is pushed
if func == slice_insert && arguments.len() == 4 {
let slice = arguments[1];
let insert_index = arguments[2];
let insert_value = arguments[3];
if let Some(insert_index) = dfg.get_numeric_constant(insert_index) {
match target_index.cmp(&insert_index) {
Ordering::Less => {
array_id = slice;
continue;
}
Ordering::Equal => {
return SimplifyResult::SimplifiedTo(insert_value);
}
Ordering::Greater => {
if !target_index.is_zero() {
array_id = slice;
target_index -= FieldElement::one();
continue;
}
}
}
}
}
}

if let Some(slice_remove) = dfg.get_intrinsic(Intrinsic::SliceRemove) {
if func == slice_remove {
let slice = arguments[1];
let remove_index = arguments[2];
if let Some(remove_index) = dfg.get_numeric_constant(remove_index) {
// Only optimize for non-composite slices
if dfg.type_of_value(slice).element_size() == 1 {
match target_index.cmp(&remove_index) {
Ordering::Less => {
array_id = slice;
continue;
}
Ordering::Equal | Ordering::Greater => {
if !target_index.is_zero() {
array_id = slice;
target_index += FieldElement::one();
continue;
}
}
}
}
}
}
}

if let Some(slice_push_front) = dfg.get_intrinsic(Intrinsic::SlicePushFront) {
// Only simplify when a single value is pushed
if func == slice_push_front && arguments.len() == 3 {
let slice = arguments[1];
let pushed_value = arguments[2];
if target_index.is_zero() {
return SimplifyResult::SimplifiedTo(pushed_value);
} else {
array_id = slice;
target_index -= FieldElement::one();
continue;
}
}
}

if let Some(slice_pop_front) = dfg.get_intrinsic(Intrinsic::SlicePopFront) {
if func == slice_pop_front {
let slice = arguments[1];
// Only optimize for non-composite slices
if dfg.type_of_value(slice).element_size() == 1 {
array_id = slice;
target_index += FieldElement::one();
continue;
}
}
}

if let Some(slice_push_back) = dfg.get_intrinsic(Intrinsic::SlicePushBack) {
// Only simplify when a single value is pushed
if func == slice_push_back {
let length = arguments[0];
let slice = arguments[1];
let pushed_value = arguments[2];
// Only optimize if the length is known
if let Some(length) = dfg.get_numeric_constant(length) {
if target_index == length {
return SimplifyResult::SimplifiedTo(pushed_value);
} else {
array_id = slice;
continue;
}
}
}
}

if let Some(slice_pop_back) = dfg.get_intrinsic(Intrinsic::SlicePopBack) {
if func == slice_pop_back {
let slice = arguments[1];
// Only optimize for non-composite slices
if dfg.type_of_value(slice).element_size() == 1 {
array_id = slice;
target_index += FieldElement::one();
continue;
}
}
}
}
_ => (),
}
} else if let Value::Param { typ: Type::Array(_, length), .. } = &dfg[array_id] {
let Some(target_index_u32) = target_index.try_to_u32() else {
return SimplifyResult::None;
};

if target_index_u32 < *length {
let index = dfg.make_constant(target_index, NumericType::length_type());
return SimplifyResult::SimplifiedToInstruction(Instruction::ArrayGet {
Expand Down Expand Up @@ -775,4 +899,208 @@ mod tests {
let ssa = Ssa::from_str_simplifying(src).unwrap();
assert_normalized_ssa_equals(ssa, src);
}

#[test]
fn simplifies_array_get_on_slice_insert() {
let src = "
acir(inline) predicate_pure fn main f0 {
b0(v0: [Field; 3]):
v2, v3 = call as_slice(v0) -> (u32, [Field])
v8, v9 = call slice_insert(u32 3, v3, u32 1, Field 10) -> (u32, [Field])
v10 = array_get v9, index u32 0 -> Field
v11 = array_get v9, index u32 1 -> Field
v12 = array_get v9, index u32 2 -> Field
v13 = array_get v9, index u32 3 -> Field
return v10, v11, v12, v13
}
";
let ssa = Ssa::from_str_simplifying(src).unwrap();

// We can see that the array_gets now read from v0 instead of v9,
// and for the index 1 we use the inserted value.
// The get with index 3 isn't simplified because it's out of bounds.
assert_ssa_snapshot!(ssa, @r"
acir(inline) predicate_pure fn main f0 {
b0(v0: [Field; 3]):
v2, v3 = call as_slice(v0) -> (u32, [Field])
v8, v9 = call slice_insert(u32 3, v3, u32 1, Field 10) -> (u32, [Field])
v11 = array_get v0, index u32 0 -> Field
v12 = array_get v0, index u32 1 -> Field
v14 = array_get v0, index u32 2 -> Field
return v11, Field 10, v12, v14
}
");
}

#[test]
fn simplifies_array_get_on_slice_remove() {
let src = "
acir(inline) predicate_pure fn main f0 {
b0(v0: [Field; 4]):
v2, v3 = call as_slice(v0) -> (u32, [Field])
v8, v9 = call slice_remove(u32 4, v3, u32 1) -> (u32, [Field])
v10 = array_get v9, index u32 0 -> Field
v11 = array_get v9, index u32 1 -> Field
v12 = array_get v9, index u32 2 -> Field
v13 = array_get v9, index u32 3 -> Field
v14 = array_get v9, index u32 4 -> Field
return v10, v11, v12, v13, v14
}
";
let ssa = Ssa::from_str_simplifying(src).unwrap();

// We can see that the array_gets now read from v0 instead of v9.
// Indexes equal or greater than the removal index now get from
// the original array at `index + 1`.
// Indexes that are less that the removal index remain the same.
assert_ssa_snapshot!(ssa, @r"
acir(inline) predicate_pure fn main f0 {
b0(v0: [Field; 4]):
v2, v3 = call as_slice(v0) -> (u32, [Field])
v7, v8 = call slice_remove(u32 4, v3, u32 1) -> (u32, [Field])
v10 = array_get v0, index u32 0 -> Field
v12 = array_get v0, index u32 2 -> Field
v14 = array_get v0, index u32 3 -> Field
v15 = array_get v8, index u32 3 -> Field
v16 = array_get v8, index u32 4 -> Field
return v10, v12, v14, v15, v16
}
");
}

#[test]
fn simplifies_array_get_on_slice_push_front() {
let src = "
acir(inline) predicate_pure fn main f0 {
b0(v0: [Field; 3]):
v2, v3 = call as_slice(v0) -> (u32, [Field])
v8, v9 = call slice_push_front(u32 3, v3, Field 10) -> (u32, [Field])
v10 = array_get v9, index u32 0 -> Field
v11 = array_get v9, index u32 1 -> Field
v12 = array_get v9, index u32 2 -> Field
v13 = array_get v9, index u32 3 -> Field
return v10, v11, v12, v13
}
";
let ssa = Ssa::from_str_simplifying(src).unwrap();

// We can see that the array_gets now read from v0 instead of v9.
// If the index is zero, we use the pushed value.
// Otherwise, the new index is `index - 1` as in the original
// slice elements happen in previous positions.
assert_ssa_snapshot!(ssa, @r"
acir(inline) predicate_pure fn main f0 {
b0(v0: [Field; 3]):
v2, v3 = call as_slice(v0) -> (u32, [Field])
v7, v8 = call slice_push_front(u32 3, v3, Field 10) -> (u32, [Field])
v10 = array_get v0, index u32 0 -> Field
v12 = array_get v0, index u32 1 -> Field
v14 = array_get v0, index u32 2 -> Field
return Field 10, v10, v12, v14
}
");
}

#[test]
fn simplifies_array_get_on_slice_pop_front() {
let src = "
acir(inline) predicate_pure fn main f0 {
b0(v0: [Field; 4]):
v2, v3 = call as_slice(v0) -> (u32, [Field])
v7, v8, v9 = call slice_pop_front(u32 4, v3) -> (Field, u32, [Field])
v10 = array_get v9, index u32 0 -> Field
v11 = array_get v9, index u32 1 -> Field
v12 = array_get v9, index u32 2 -> Field
v13 = array_get v9, index u32 3 -> Field
v14 = array_get v9, index u32 4 -> Field
return v10, v11, v12, v13, v14
}
";
let ssa = Ssa::from_str_simplifying(src).unwrap();

// We can see that the array_gets now read from v0 instead of v9.
// Indexes equal or greater than the removal index now get from
// the original array at `index + 1`.
// Indexes that are less that the removal index remain the same.
assert_ssa_snapshot!(ssa, @r"
acir(inline) predicate_pure fn main f0 {
b0(v0: [Field; 4]):
v2, v3 = call as_slice(v0) -> (u32, [Field])
v6, v7, v8 = call slice_pop_front(u32 4, v3) -> (Field, u32, [Field])
v10 = array_get v0, index u32 1 -> Field
v12 = array_get v0, index u32 2 -> Field
v14 = array_get v0, index u32 3 -> Field
v15 = array_get v8, index u32 3 -> Field
v16 = array_get v8, index u32 4 -> Field
return v10, v12, v14, v15, v16
}
");
}

#[test]
fn simplifies_array_get_on_slice_push_back() {
let src = "
acir(inline) predicate_pure fn main f0 {
b0(v0: [Field; 3]):
v2, v3 = call as_slice(v0) -> (u32, [Field])
v8, v9 = call slice_push_back(u32 3, v3, Field 10) -> (u32, [Field])
v10 = array_get v9, index u32 0 -> Field
v11 = array_get v9, index u32 1 -> Field
v12 = array_get v9, index u32 2 -> Field
v13 = array_get v9, index u32 3 -> Field
v14 = array_get v9, index u32 4 -> Field
return v10, v11, v12, v13, v14
}
";
let ssa = Ssa::from_str_simplifying(src).unwrap();

// We can see that the array_gets now read from v0 instead of v9,
// When the index is the same as the slice length, we use the pushed value.
// Otherwise, the index remains the same.
assert_ssa_snapshot!(ssa, @r"
acir(inline) predicate_pure fn main f0 {
b0(v0: [Field; 3]):
v2, v3 = call as_slice(v0) -> (u32, [Field])
v7, v8 = call slice_push_back(u32 3, v3, Field 10) -> (u32, [Field])
v10 = array_get v0, index u32 0 -> Field
v12 = array_get v0, index u32 1 -> Field
v14 = array_get v0, index u32 2 -> Field
v16 = array_get v8, index u32 4 -> Field
return v10, v12, v14, Field 10, v16
}
");
}

#[test]
fn simplifies_array_get_on_slice_pop_back() {
let src = "
acir(inline) predicate_pure fn main f0 {
b0(v0: [Field; 4]):
v2, v3 = call as_slice(v0) -> (u32, [Field])
v7, v8, v9 = call slice_pop_back(u32 4, v3) -> (u32, [Field], Field)
v10 = array_get v8, index u32 0 -> Field
v11 = array_get v8, index u32 1 -> Field
v12 = array_get v8, index u32 2 -> Field
v13 = array_get v8, index u32 3 -> Field
v14 = array_get v8, index u32 4 -> Field
return v10, v11, v12, v13, v14
}
";
let ssa = Ssa::from_str_simplifying(src).unwrap();

// We can see that the array_gets now read from v0 instead of v9.
assert_ssa_snapshot!(ssa, @r"
acir(inline) predicate_pure fn main f0 {
b0(v0: [Field; 4]):
v2, v3 = call as_slice(v0) -> (u32, [Field])
v6, v7, v8 = call slice_pop_back(u32 4, v3) -> (u32, [Field], Field)
v10 = array_get v0, index u32 1 -> Field
v12 = array_get v0, index u32 2 -> Field
v14 = array_get v0, index u32 3 -> Field
v15 = array_get v7, index u32 3 -> Field
v16 = array_get v7, index u32 4 -> Field
return v10, v12, v14, v15, v16
}
");
}
}
Loading
Loading