diff --git a/compiler/noirc_evaluator/src/ssa/opt/array_set.rs b/compiler/noirc_evaluator/src/ssa/opt/array_set.rs index 8556403003e..ddc0c5c7b34 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/array_set.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/array_set.rs @@ -1,12 +1,26 @@ +//! The purpose of the `array_set_optimization` SSA pass is to mark `ArraySet` instructions +//! as mutable _iff_ the array is not potentially shared with the callers or callees of the +//! function and won't be used again in the function itself either. In other words, if this +//! is the last time we use this version of the array, we can mutate it in place, and avoid +//! having to make a copy of it. +//! +//! This optimization only applies to ACIR. In Brillig we use ref-counting to decide when +//! there are no other references to an array. +//! +//! The pass is expected to run at most once, and requires these passes to occur before itself: +//! * unrolling +//! * flattening +//! * removal of if-else instructions + +use core::panic; use std::mem; use crate::ssa::{ ir::{ basic_block::BasicBlockId, dfg::DataFlowGraph, - function::{Function, RuntimeType}, + function::Function, instruction::{Instruction, InstructionId, TerminatorInstruction}, - types::Type::{Array, Slice}, value::ValueId, }, ssa_gen::Ssa, @@ -14,9 +28,9 @@ use crate::ssa::{ use fxhash::{FxHashMap as HashMap, FxHashSet as HashSet}; impl Ssa { - /// Map arrays with the last instruction that uses it - /// For this we simply process all the instructions in execution order - /// and update the map whenever there is a match + /// Finds the last instruction that writes to an array and modifies it + /// to do an in-place mutation instead of making a copy if there are + /// no potential shared references to it. 
#[tracing::instrument(level = "trace", skip(self))] pub(crate) fn array_set_optimization(mut self) -> Self { for func in self.functions.values_mut() { @@ -35,15 +49,42 @@ impl Ssa { /// Pre-check condition for [Function::array_set_optimization]. /// /// Panics if: -/// - there already exists a mutable array set instruction. +/// - An ACIR function contains more than 1 block, i.e. it hasn't been flattened yet. +/// - There already exists a mutable array set instruction. +/// - There is an `IfElse` instruction which hasn't been removed yet. #[cfg(debug_assertions)] fn array_set_optimization_pre_check(func: &Function) { - // There should be no mutable array sets. - for block_id in func.reachable_blocks() { + // We only want to run this pass for ACIR. + if func.runtime().is_brillig() { + return; + } + + let reachable_blocks = func.reachable_blocks(); + + if !func.runtime().is_entry_point() { + assert_eq!( + reachable_blocks.len(), + 1, + "Expected there to be 1 block remaining in ACIR function for array_set optimization" + ); + } + + for block_id in reachable_blocks { let instruction_ids = func.dfg[block_id].instructions(); for instruction_id in instruction_ids { - if matches!(func.dfg[*instruction_id], Instruction::ArraySet { mutable: true, .. }) { - panic!("mutable ArraySet instruction exists before `array_set_optimization` pass"); + match func.dfg[*instruction_id] { + // There should be no mutable array sets. + Instruction::ArraySet { mutable: true, .. } => { + panic!( + "mutable ArraySet instruction exists before `array_set_optimization` pass" + ); + } + // The pass might mutate an array result of an `IfElse` and thus modify the input even if it's used later, + // so we assert that such instructions have already been removed by the `remove_if_else` pass. + Instruction::IfElse { .. 
} => { + panic!("IfElse instruction exists before `array_set_optimization` pass"); + } + _ => {} } } } @@ -52,7 +93,7 @@ fn array_set_optimization_pre_check(func: &Function) { /// Post-check condition for [Function::array_set_optimization]. /// /// Panics if: -/// - Mutable array_set optimization has been applied to Brillig function +/// - Mutable array_set optimization has been applied to Brillig function. #[cfg(debug_assertions)] fn array_set_optimization_post_check(func: &Function) { // Brillig functions should be not have any mutable array sets. @@ -71,28 +112,18 @@ fn array_set_optimization_post_check(func: &Function) { impl Function { pub(crate) fn array_set_optimization(&mut self) { - if matches!(self.runtime(), RuntimeType::Brillig(_)) { - // Brillig is supposed to use refcounting to decide whether to mutate an array; + if self.runtime().is_brillig() { + // Brillig is supposed to use ref-counting to decide whether to mutate an array; // array mutation was only meant for ACIR. We could use it with Brillig as well, // but then some of the optimizations that we can do in ACIR around shared // references have to be skipped, which makes it more cumbersome. return; } - let reachable_blocks = self.reachable_blocks(); - - if !self.runtime().is_entry_point() { - assert_eq!( - reachable_blocks.len(), - 1, - "Expected there to be 1 block remaining in Acir function for array_set optimization" - ); - } - let mut context = Context::new(&self.dfg); - for block in reachable_blocks.iter() { - context.analyze_last_uses(*block); + for block in self.reachable_blocks() { + context.analyze_last_uses(block); } let instructions_to_update = mem::take(&mut context.instructions_that_can_be_made_mutable); @@ -119,50 +150,62 @@ impl<'f> Context<'f> { } } - /// Builds the set of ArraySet instructions that can be made mutable + /// Remember this instruction as the last time the array has been read or written to. 
+ /// + /// Any previous instruction marked to be made mutable needs to be cancelled, + /// as it turned out not to be the last use. + fn set_last_use(&mut self, array: ValueId, instruction_id: InstructionId) { + if let Some(existing) = self.array_to_last_use.insert(array, instruction_id) { + self.instructions_that_can_be_made_mutable.remove(&existing); + } + } + + /// Builds the set of `ArraySet` instructions that can be made mutable /// because their input value is unused elsewhere afterward. + /// + /// Only expected to execute on ACIR functions. fn analyze_last_uses(&mut self, block_id: BasicBlockId) { + assert!(self.dfg.runtime().is_acir()); + let block = &self.dfg[block_id]; + let terminator = self.dfg[block_id].unwrap_terminator(); + + // If we are in a return block we are not concerned about the array potentially being mutated again. + // In ACIR this should be the only kind of block we encounter, unless it's marked unreachable, + // in which case we don't need to optimize the array writes since we will end up with a failure anyway. + match terminator { + TerminatorInstruction::Return { .. } => {} + TerminatorInstruction::Unreachable { .. } => { + return; + } + other => { + panic!("unexpected terminator in ACIR: {other:?}") + } + }; + for instruction_id in block.instructions() { match &self.dfg[*instruction_id] { + // Reading an array constitutes a use, replacing any previous last use. Instruction::ArrayGet { array, .. } => { - let array = *array; - - if let Some(existing) = self.array_to_last_use.insert(array, *instruction_id) { - self.instructions_that_can_be_made_mutable.remove(&existing); - } + self.set_last_use(*array, *instruction_id); } + // Writing to an array is a use; mark it for mutation unless it might be shared. Instruction::ArraySet { array, ..
} => { - let array = *array; - - if let Some(existing) = self.array_to_last_use.insert(array, *instruction_id) { - self.instructions_that_can_be_made_mutable.remove(&existing); - } + self.set_last_use(*array, *instruction_id); // If the array we are setting does not come from a load we can safely mark it mutable. // If the array comes from a load we may potentially being mutating an array at a reference // that is loaded from by other values. - let terminator = self.dfg[block_id].unwrap_terminator(); - - // If we are in a return block we are not concerned about the array potentially being mutated again. - let is_return_block = - matches!(terminator, TerminatorInstruction::Return { .. }); // We also want to check that the array is not part of the terminator arguments, as this means it is used again. let mut is_array_in_terminator = false; terminator.for_each_value(|value| { - // The terminator can contain original IDs, while the SSA has replaced the array value IDs; we need to resolve to compare. - if !is_array_in_terminator && value == array { - is_array_in_terminator = true; - } + is_array_in_terminator |= value == *array; }); - let can_mutate = if let Some(is_from_param) = self.arrays_from_load.get(&array) - { - // If the array was loaded from a reference parameter, we cannot - // safely mark that array mutable as it may be shared by another value. - !is_from_param && is_return_block + let can_mutate = if let Some(is_from_param) = self.arrays_from_load.get(array) { + !is_from_param } else { !is_array_in_terminator }; @@ -171,37 +214,36 @@ impl<'f> Context<'f> { self.instructions_that_can_be_made_mutable.insert(*instruction_id); } } + // Array arguments passed in calls constitute a use. Instruction::Call { arguments, .. } => { for argument in arguments { - if matches!(self.dfg.type_of_value(*argument), Array { .. } | Slice { .. 
}) - { - let argument = *argument; - - if let Some(existing) = - self.array_to_last_use.insert(argument, *instruction_id) - { - self.instructions_that_can_be_made_mutable.remove(&existing); - } + if self.dfg.type_of_value(*argument).is_array() { + self.set_last_use(*argument, *instruction_id); } } } + // Arrays loaded from input references might be shared with the caller. Instruction::Load { address } => { let result = self.dfg.instruction_results(*instruction_id)[0]; - if matches!(self.dfg.type_of_value(result), Array { .. } | Slice { .. }) { + if self.dfg.type_of_value(result).is_array() { let is_reference_param = self.dfg.block_parameters(block_id).contains(address); self.arrays_from_load.insert(result, is_reference_param); } } + // Arrays nested in other arrays are a use. Instruction::MakeArray { elements, .. } => { for element in elements { - if let Some(existing) = - self.array_to_last_use.insert(*element, *instruction_id) - { - self.instructions_that_can_be_made_mutable.remove(&existing); + if self.dfg.type_of_value(*element).is_array() { + self.set_last_use(*element, *instruction_id); } } } + Instruction::IfElse { .. } => { + panic!( + "IfElse instructions are assumed to be removed before array_set optimization" + ) + } _ => (), } } @@ -224,7 +266,7 @@ fn make_mutable(dfg: &mut DataFlowGraph, instructions_to_update: &HashSet &mut [Field; 1] + store v1 at v2 + v3 = load v2 -> [Field; 1] + v6 = array_set mut v3, index u32 0, value Field 1 + return + } + "); + } + + #[test] + fn does_not_mutate_arrays_in_unreachable_blocks() { + let src = " + acir(inline) fn main f0 { + b0(): + v1 = make_array [Field 0] : [Field; 1] + v2 = allocate -> &mut [Field; 1] + store v1 at v2 + v3 = load v2 -> [Field; 1] + v4 = array_set v3, index u32 0, value Field 1 + constrain u1 0 == u1 1 + unreachable + } + "; + assert_ssa_does_not_change(src, Ssa::array_set_optimization); + } + + // Demonstrate that we assume that `IfElse` instructions have been + // removed by previous passes. 
Otherwise we would need to handle transitive + // relations between arrays. + #[test] + #[should_panic] + fn assumes_no_if_else() { + // v4 can be v1 or v2. v1 is returned, so v4 should not be mutated. + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: u1, v1: [u32; 2], v2: [u32; 2]): + v3 = not v0 + v4 = if v0 then v1 else (if v3) v2 + v5 = array_set v4, index u32 0, value u32 1 + return v1 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let _ssa = ssa.array_set_optimization(); + } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/compiler/noirc_evaluator/src/ssa/opt/mod.rs index d73063fb783..7c50012773b 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -118,3 +118,14 @@ macro_rules! assert_ssa_snapshot { insta::assert_snapshot!(ssa_string, $($arg)*) }; } + +/// Assert that running a certain pass on the SSA does nothing. +#[cfg(test)] +pub(crate) fn assert_ssa_does_not_change( + src: &str, + pass: impl FnOnce(crate::ssa::Ssa) -> crate::ssa::Ssa, +) { + let ssa = crate::ssa::Ssa::from_str(src).unwrap(); + let ssa = pass(ssa); + assert_normalized_ssa_equals(ssa, src); +}