Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions compiler/noirc_evaluator/src/ssa/opt/unrolling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,10 @@ impl Loop {
let mut unroll_into = self.get_pre_header(function, cfg)?;
let mut jump_value = get_induction_variable(&function.dfg, unroll_into)?;

while let Some(context) = self.unroll_header(function, unroll_into, jump_value)? {
(unroll_into, jump_value) = context.unroll_loop_iteration();
while let Some((context, loop_header_id)) =
self.unroll_header(function, unroll_into, jump_value)?
{
(unroll_into, jump_value) = context.unroll_loop_iteration(loop_header_id);
}

Ok(())
Expand Down Expand Up @@ -600,27 +602,30 @@ impl Loop {

/// Unrolls the header block of the loop. This is the block that dominates all other blocks in the
/// loop and contains the jmpif instruction that lets us know if we should continue looping.
/// Returns Some(iteration context) if we should perform another iteration.
/// Returns Some((iteration context, loop_header_id)) if we should perform another iteration.
fn unroll_header<'a>(
&'a self,
function: &'a mut Function,
unroll_into: BasicBlockId,
induction_value: ValueId,
) -> Result<Option<LoopIteration<'a>>, CallStack> {
) -> Result<Option<(LoopIteration<'a>, BasicBlockId)>, CallStack> {
// We insert into a fresh block first and move instructions into the unroll_into block later
// only once we verify the jmpif instruction has a constant condition. If it does not, we can
// just discard this fresh block and leave the loop unmodified.
let fresh_block = function.dfg.make_block();

let mut context = LoopIteration::new(function, self, fresh_block, self.header);
let source_block = &context.dfg()[context.source_block];
let loop_header_id = context.source_block;
let source_block = &context.dfg()[loop_header_id];
assert_eq!(source_block.parameters().len(), 1, "Expected only 1 argument in loop header");

// Insert the current value of the loop induction variable into our context.
let first_param = source_block.parameters()[0];
context.inserter.try_map_value(first_param, induction_value);
// Copy over all instructions and a fresh terminator.
context.inline_instructions_from_block();
context.visited_blocks.insert(loop_header_id);

// Mutate the terminator if possible so that it points at the iteration block.
match context.dfg()[fresh_block].unwrap_terminator() {
TerminatorInstruction::JmpIf {
Expand Down Expand Up @@ -651,7 +656,11 @@ impl Loop {
// have no more loops to unroll, because that block was not part of the loop itself,
// ie. it wasn't between `loop_header` and `loop_body`. Otherwise we have the `loop_body`
// in `source_block` and can unroll that into the destination.
Ok(self.blocks.contains(&context.source_block).then_some(context))
Ok(self
.blocks
.contains(&context.source_block)
.then_some(context)
.map(|iteration_context| (iteration_context, loop_header_id)))
} else {
// If this case is reached the loop either uses non-constant indices or we need
// another pass, such as mem2reg to resolve them to constants.
Expand Down Expand Up @@ -967,6 +976,9 @@ struct LoopIteration<'f> {
original_blocks: HashMap<BasicBlockId, BasicBlockId>,
visited_blocks: HashSet<BasicBlockId>,

/// Has `unroll_loop_iteration` reached the `loop_header_id`?
encountered_loop_header: bool,

insert_block: BasicBlockId,
source_block: BasicBlockId,

Expand All @@ -992,6 +1004,8 @@ impl<'f> LoopIteration<'f> {
blocks: HashMap::default(),
original_blocks: HashMap::default(),
visited_blocks: HashSet::default(),
encountered_loop_header: false,

induction_value: None,
}
}
Expand All @@ -1002,13 +1016,14 @@ impl<'f> LoopIteration<'f> {
/// It is expected the terminator instructions are set up to branch into an empty block
/// for further unrolling. When the loop is finished this will need to be mutated to
/// jump to the end of the loop instead.
fn unroll_loop_iteration(mut self) -> (BasicBlockId, ValueId) {
fn unroll_loop_iteration(mut self, loop_header_id: BasicBlockId) -> (BasicBlockId, ValueId) {
// Kick off the unrolling from the initial source block.
let mut next_blocks = self.unroll_loop_block();

while let Some(block) = next_blocks.pop() {
self.insert_block = block;
self.source_block = self.get_original_block(block);
self.encountered_loop_header |= loop_header_id == self.source_block;

if !self.visited_blocks.contains(&self.source_block) {
let mut blocks = self.unroll_loop_block();
Expand All @@ -1021,6 +1036,11 @@ impl<'f> LoopIteration<'f> {
.induction_value
.expect("Expected to find the induction variable by end of loop iteration");

assert!(
self.encountered_loop_header,
"expected to encounter loop header when visiting blocks"
);

(end_block, induction_value)
}

Expand Down
Loading