diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index f4faffcb7ae..0e097a012ec 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -567,8 +567,10 @@ impl Loop { let mut unroll_into = self.get_pre_header(function, cfg)?; let mut jump_value = get_induction_variable(&function.dfg, unroll_into)?; - while let Some(context) = self.unroll_header(function, unroll_into, jump_value)? { - (unroll_into, jump_value) = context.unroll_loop_iteration(); + while let Some((context, loop_header_id)) = + self.unroll_header(function, unroll_into, jump_value)? + { + (unroll_into, jump_value) = context.unroll_loop_iteration(loop_header_id); } Ok(()) @@ -600,20 +602,21 @@ impl Loop { /// Unrolls the header block of the loop. This is the block that dominates all other blocks in the /// loop and contains the jmpif instruction that lets us know if we should continue looping. - /// Returns Some(iteration context) if we should perform another iteration. + /// Returns Some((iteration context, loop_header_id)) if we should perform another iteration. fn unroll_header<'a>( &'a self, function: &'a mut Function, unroll_into: BasicBlockId, induction_value: ValueId, - ) -> Result>, CallStack> { + ) -> Result, BasicBlockId)>, CallStack> { // We insert into a fresh block first and move instructions into the unroll_into block later // only once we verify the jmpif instruction has a constant condition. If it does not, we can // just discard this fresh block and leave the loop unmodified. let fresh_block = function.dfg.make_block(); let mut context = LoopIteration::new(function, self, fresh_block, self.header); - let source_block = &context.dfg()[context.source_block]; + let loop_header_id = context.source_block; + let source_block = &context.dfg()[loop_header_id]; assert_eq!(source_block.parameters().len(), 1, "Expected only 1 argument in loop header"); // Insert the current value of the loop induction variable into our context. @@ -621,6 +624,8 @@ impl Loop { context.inserter.try_map_value(first_param, induction_value); // Copy over all instructions and a fresh terminator. context.inline_instructions_from_block(); + context.visited_blocks.insert(loop_header_id); + // Mutate the terminator if possible so that it points at the iteration block. match context.dfg()[fresh_block].unwrap_terminator() { TerminatorInstruction::JmpIf { @@ -651,7 +656,11 @@ impl Loop { // have no more loops to unroll, because that block was not part of the loop itself, // ie. it wasn't between `loop_header` and `loop_body`. Otherwise we have the `loop_body` // in `source_block` and can unroll that into the destination. - Ok(self.blocks.contains(&context.source_block).then_some(context)) + Ok(self + .blocks + .contains(&context.source_block) + .then_some(context) + .map(|iteration_context| (iteration_context, loop_header_id))) } else { // If this case is reached the loop either uses non-constant indices or we need // another pass, such as mem2reg to resolve them to constants. @@ -967,6 +976,9 @@ struct LoopIteration<'f> { original_blocks: HashMap, visited_blocks: HashSet, + /// Has `unroll_loop_iteration` reached the `loop_header_id`? + encountered_loop_header: bool, + insert_block: BasicBlockId, source_block: BasicBlockId, @@ -992,6 +1004,8 @@ impl<'f> LoopIteration<'f> { blocks: HashMap::default(), original_blocks: HashMap::default(), visited_blocks: HashSet::default(), + encountered_loop_header: false, + induction_value: None, } } @@ -1002,13 +1016,14 @@ impl<'f> LoopIteration<'f> { /// It is expected the terminator instructions are set up to branch into an empty block /// for further unrolling. When the loop is finished this will need to be mutated to /// jump to the end of the loop instead. - fn unroll_loop_iteration(mut self) -> (BasicBlockId, ValueId) { + fn unroll_loop_iteration(mut self, loop_header_id: BasicBlockId) -> (BasicBlockId, ValueId) { // Kick off the unrolling from the initial source block. let mut next_blocks = self.unroll_loop_block(); while let Some(block) = next_blocks.pop() { self.insert_block = block; self.source_block = self.get_original_block(block); + self.encountered_loop_header |= loop_header_id == self.source_block; if !self.visited_blocks.contains(&self.source_block) { let mut blocks = self.unroll_loop_block(); @@ -1021,6 +1036,11 @@ impl<'f> LoopIteration<'f> { .induction_value .expect("Expected to find the induction variable by end of loop iteration"); + assert!( + self.encountered_loop_header, + "expected to encounter loop header when visiting blocks" + ); + (end_block, induction_value) }