diff --git a/compiler/noirc_evaluator/src/ssa/opt/pure.rs b/compiler/noirc_evaluator/src/ssa/opt/pure.rs index 4321a441338..68d99bcb1bd 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/pure.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/pure.rs @@ -1,4 +1,3 @@ -use std::collections::BTreeSet; use std::sync::Arc; use fxhash::FxHashMap as HashMap; @@ -28,20 +27,15 @@ impl Ssa { /// identified as calling known pure functions. #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn purity_analysis(mut self) -> Ssa { - let mut purities = HashMap::default(); - let mut called_functions = HashMap::default(); - // First look through each function to get a baseline on its purity and collect // the functions it calls to build a call graph. - for function in self.functions.values() { - let (purity, dependencies) = function.is_pure(); - purities.insert(function.id(), purity); - called_functions.insert(function.id(), dependencies); - } + let purities: HashMap<_, _> = + self.functions.values().map(|function| (function.id(), function.is_pure())).collect(); // Then transitively 'infect' any functions which call impure functions as also // impure. - let purities = analyze_call_graph(called_functions, purities, self.main_id); + let call_graph = CallGraph::from_ssa(&self); + let purities = analyze_call_graph(call_graph, purities); let purities = Arc::new(purities); // We're done, now store purities somewhere every dfg can find it. @@ -117,18 +111,16 @@ impl std::fmt::Display for Purity { } impl Function { - fn is_pure(&self) -> (Purity, BTreeSet<FunctionId>) { - // Note, this function must be allowed to complete despite the fact that once the function is marked as impure - // then its final purity is known. This is because we need to collect all of the dependencies of the function - // to ensure that they are processed. - // - // This can be relaxed if we calculate the callgraph separately. 
- + fn is_pure(&self) -> Purity { let contains_reference = |value_id: &ValueId| { let typ = self.dfg.type_of_value(*value_id); typ.contains_reference() }; + if self.parameters().iter().any(&contains_reference) { + return Purity::Impure; + } + let mut result = if self.runtime().is_acir() { Purity::Pure } else { @@ -137,17 +129,6 @@ impl Function { Purity::PureWithPredicate }; - // Set of functions we call which the purity result depends on. - // `is_pure` is intended to be called on each function, building - // up a call graph of sorts to check afterwards to propagate impurity - // from called functions to their callers. Therefore, an initial "Pure" - // result here could be overridden by one of these dependencies being impure. - let mut dependencies = BTreeSet::new(); - - if self.parameters().iter().any(&contains_reference) { - result = Purity::Impure; - } - for block in self.reachable_blocks() { for instruction in self.dfg[block].instructions() { // We don't defer to Instruction::can_be_deduplicated, Instruction::requires_acir_gen_predicate, @@ -156,10 +137,10 @@ impl Function { // parameters or returned, we can ignore them. // We even ignore Constrain instructions. As long as the external parameters are // identical, we should be constraining the same values anyway. - let instruction_purity = match &self.dfg[*instruction] { + match &self.dfg[*instruction] { Instruction::Constrain(..) | Instruction::ConstrainNotEqual(..) - | Instruction::RangeCheck { .. } => Purity::PureWithPredicate, + | Instruction::RangeCheck { .. } => result = Purity::PureWithPredicate, // These instructions may be pure unless: // - We may divide by zero @@ -170,35 +151,35 @@ impl Function { | Instruction::ArrayGet { .. } | Instruction::ArraySet { .. }) => { if ins.requires_acir_gen_predicate(&self.dfg) { - Purity::PureWithPredicate - } else { - result + result = Purity::PureWithPredicate; } } Instruction::Call { func, .. 
} => { match &self.dfg[*func] { - Value::Function(function_id) => { + Value::Function(_) => { // We don't know if this function is pure or not yet, - // so track it as a dependency for now. - dependencies.insert(*function_id); - result + // + // `is_pure` is intended to be called on each function, building + // up a call graph of sorts to check afterwards to propagate impurity + // from called functions to their callers. Therefore, an initial "Pure" + // result here could be overridden by one of these dependencies being impure. } Value::Intrinsic(intrinsic) => match intrinsic.purity() { - Purity::Pure => result, - Purity::PureWithPredicate => Purity::PureWithPredicate, - Purity::Impure => Purity::Impure, + Purity::Pure => (), + Purity::PureWithPredicate => result = Purity::PureWithPredicate, + Purity::Impure => return Purity::Impure, }, - Value::ForeignFunction(_) => Purity::Impure, + Value::ForeignFunction(_) => return Purity::Impure, // The function we're calling is unknown in the remaining cases, // so just assume the worst. Value::Global(_) | Value::Instruction { .. } | Value::Param { .. } - | Value::NumericConstant { .. } => Purity::Impure, + | Value::NumericConstant { .. } => return Purity::Impure, } } - // The rest are always pure (including allocate, load, & store) and so don't affect purity + // The rest are always pure (including allocate, load, & store) Instruction::Cast(_, _) | Instruction::Not(_) | Instruction::Truncate { .. } @@ -210,67 +191,78 @@ impl Function { | Instruction::DecrementRc { .. } | Instruction::IfElse { .. } | Instruction::MakeArray { .. } - | Instruction::Noop => result, + | Instruction::Noop => (), }; - - result = result.unify(instruction_purity); } // If the function returns a reference it is impure let terminator = self.dfg[block].terminator(); if let Some(TerminatorInstruction::Return { return_values, .. 
}) = terminator { if return_values.iter().any(&contains_reference) { - result = Purity::Impure; + return Purity::Impure; } } } - (result, dependencies) + result } } fn analyze_call_graph( - dependencies: HashMap<FunctionId, BTreeSet<FunctionId>>, + call_graph: CallGraph, starting_purities: FunctionPurities, - main: FunctionId, ) -> FunctionPurities { - let call_graph = CallGraph::from_deps(dependencies); - // Now we can analyze it: a function is only as pure as all of // its called functions - let main_index = call_graph.ids_to_indices()[&main]; - let graph = call_graph.graph(); - let mut dfs = DfsPostOrder::new(graph, main_index); + let times_called = call_graph.times_called(); + let starting_points = + times_called.iter().filter_map(|(id, times_called)| (*times_called == 0).then_some(*id)); // The `starting_purities` are the preliminary results from `is_pure` // that don't take into account function calls. These finished purities do. let mut finished_purities = HashMap::default(); + let graph = call_graph.graph(); + let ids_to_indices = call_graph.ids_to_indices(); let indices_to_ids = call_graph.indices_to_ids(); - while let Some(index) = dfs.next(graph) { - let id = indices_to_ids[&index]; - let mut purity = starting_purities[&id]; - - for neighbor_index in graph.neighbors(index) { - let neighbor = indices_to_ids[&neighbor_index]; - - let neighbor_purity = finished_purities.get(&neighbor).copied().unwrap_or({ - // The dependent function isn't finished yet. Since we're following - // calls in a DFS, this means there are mutually recursive functions. - // We could handle these but would need a different, much slower algorithm - // to detect strongly connected components. Instead, since this should be - // a rare case, we bail and assume impure for now. 
- if neighbor == id { - // If the recursive call is to the same function we can ignore it - purity - } else { - Purity::Impure - } - }); - purity = purity.unify(neighbor_purity); + + for start_point in starting_points { + let start_index = ids_to_indices[&start_point]; + let mut dfs = DfsPostOrder::new(graph, start_index); + + while let Some(index) = dfs.next(graph) { + let id = indices_to_ids[&index]; + let mut purity = starting_purities[&id]; + + for neighbor_index in graph.neighbors(index) { + let neighbor = indices_to_ids[&neighbor_index]; + + let neighbor_purity = finished_purities.get(&neighbor).copied().unwrap_or({ + // The dependent function isn't finished yet. Since we're following + // calls in a DFS, this means there are mutually recursive functions. + // We could handle these but would need a different, much slower algorithm + // to detect strongly connected components. Instead, since this should be + // a rare case, we bail and assume impure for now. + if neighbor == id { + // If the recursive call is to the same function we can ignore it + purity + } else { + Purity::Impure + } + }); + purity = purity.unify(neighbor_purity); + } + + finished_purities.insert(id, purity); } + } - finished_purities.insert(id, purity); + // Any remaining functions are completely unreachable and are either recursive or mutually recursive. + // As these functions will be removed from the program, we treat them as impure. 
+ let unhandled_funcs: Vec<_> = + starting_purities.keys().filter(|func| !finished_purities.contains_key(*func)).collect(); + for id in unhandled_funcs { + finished_purities.insert(*id, Purity::Impure); } finished_purities @@ -466,4 +458,25 @@ mod test { assert_eq!(purities[&FunctionId::test_new(2)], Purity::Impure); assert_eq!(purities[&FunctionId::test_new(3)], Purity::PureWithPredicate); } + + #[test] + fn handles_unreachable_functions() { + // Regression test for https://github.com/noir-lang/noir/issues/8666 + let src = r#" + brillig(inline) fn main f0 { + b0(): + return + } + brillig(inline) fn func_1 f1 { + b0(): + return + }"#; + + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.purity_analysis(); + + let purities = &ssa.main().dfg.function_purities; + assert_eq!(purities[&FunctionId::test_new(0)], Purity::PureWithPredicate); + assert_eq!(purities[&FunctionId::test_new(1)], Purity::PureWithPredicate); + } }