Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 167 additions & 0 deletions compiler/noirc_evaluator/src/ssa/ir/call_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,75 @@ mod tests {
assert!(recursive_functions.contains(&Id::test_new(3)));
}

#[test]
fn mark_multiple_independent_recursion_cycles() {
// This test is an expanded version of `mark_mutually_recursive_functions` where we have multiple recursive cycles.
let src = "
acir(inline) fn main f0 {
b0():
call f1()
call f4()
return
}
// First recursive cycle: f1 -> f2 -> f3 -> f1
brillig(inline) fn starter f1 {
b0():
call f2()
return
}
brillig(inline) fn ping f2 {
b0():
call f3()
return
}
brillig(inline) fn pong f3 {
b0():
call f1()
return
}
// Second recursive cycle: f4 <-> f5
brillig(inline) fn foo f4 {
b0():
call f5()
return
}
brillig(inline) fn bar f5 {
b0():
call f4()
return
}
// Non-recursive leaf function
brillig(inline) fn baz f6 {
b0():
return
}
";

let ssa = Ssa::from_str(src).unwrap();
let call_graph = CallGraph::from_ssa_weighted(&ssa);
let recursive_functions = call_graph.get_recursive_functions();

// There should be 5 recursive functions: f1, f2, f3 (cycle 1), and f4, f5 (cycle 2)
let expected_recursive_ids = [1, 2, 3, 4, 5].map(Id::test_new).to_vec();

assert_eq!(
recursive_functions.len(),
expected_recursive_ids.len(),
"Expected {} recursive functions",
expected_recursive_ids.len()
);

for func_id in expected_recursive_ids {
assert!(
recursive_functions.contains(&func_id),
"Function {func_id} should be marked recursive",
);
}

// f6 should not be marked recursive
assert!(!recursive_functions.contains(&Id::test_new(6)));
}

#[test]
fn mark_self_recursive_function() {
let src = "
Expand All @@ -368,6 +437,82 @@ mod tests {
assert!(recursive_functions.contains(&Id::test_new(1)));
}

#[test]
fn self_recursive_and_calls_others() {
let src = "
acir(inline) fn main f0 {
b0():
call f1()
return
}
brillig(inline) fn self_recur f1 {
b0():
call f1()
call f2()
return
}
brillig(inline) fn foo f2 {
b0():
return
}
";
let ssa = Ssa::from_str(src).unwrap();
let call_graph = CallGraph::from_ssa_weighted(&ssa);

let f0 = Id::test_new(0);
let f1 = Id::test_new(1);
let f2 = Id::test_new(2);

let recursive = call_graph.get_recursive_functions();
assert!(recursive.contains(&f1));
assert!(!recursive.contains(&f0));
assert!(!recursive.contains(&f2));

let callees = call_graph.callees();
let f1_callees = callees.get(&f1).unwrap();
assert_eq!(f1_callees.len(), 2);
assert_eq!(*f1_callees.get(&f1).unwrap(), 1, "f1 should call itself once");
assert_eq!(*f1_callees.get(&f2).unwrap(), 1, "f1 should call f2 once");

let callers = call_graph.callers();
let f1_callers = callers.get(&f1).unwrap();
assert_eq!(f1_callers.len(), 2);
assert_eq!(*f1_callers.get(&f0).unwrap(), 1, "f0 calls f1 once");
assert_eq!(*f1_callers.get(&f1).unwrap(), 1, "f1 calls itself once");

let f2_callers = callers.get(&f2).unwrap();
assert_eq!(f2_callers.len(), 1);
assert_eq!(*f2_callers.get(&f1).unwrap(), 1, "f1 calls f2 once");

let f2_callees = callees.get(&f2).unwrap();
assert!(f2_callees.is_empty(), "f2 should not call any functions");

let f0_callees = callees.get(&f0).unwrap();
assert_eq!(f0_callees.len(), 1);
assert_eq!(*f0_callees.get(&f1).unwrap(), 1);
}

#[test]
fn pure_self_recursive_function() {
let src = "
brillig(inline) fn self_recur f0 {
b0():
call f0()
return
}
";
let ssa = Ssa::from_str(src).unwrap();
let call_graph = CallGraph::from_ssa_weighted(&ssa);

let recursive = call_graph.get_recursive_functions();
assert!(recursive.contains(&Id::test_new(0)));

let callers = call_graph.callers();
let f0_callers = callers.get(&Id::test_new(0)).unwrap();
assert_eq!(f0_callers.len(), 1);
assert_eq!(*f0_callers.get(&Id::test_new(0)).unwrap(), 1);
}

fn callers_and_callees_src() -> &'static str {
r#"
acir(inline) fn main f0 {
Expand Down Expand Up @@ -506,4 +651,26 @@ mod tests {
*times_called.get(&Id::test_new(4)).expect(" Should have times called");
assert_eq!(times_f4_called, 2);
}

#[test]
fn dead_function_not_called() {
let src = "
acir(inline) fn main f0 {
b0():
return
}
brillig(inline) fn dead_code f1 {
b0():
return
}
";
let ssa = Ssa::from_str(src).unwrap();
let call_graph = CallGraph::from_ssa_weighted(&ssa);

// f1 is never called, but it should still be tracked.
let times_called = call_graph.times_called();
assert_eq!(*times_called.get(&Id::test_new(1)).unwrap(), 0);
assert!(call_graph.callers().get(&Id::test_new(1)).unwrap().is_empty());
assert!(call_graph.callees().get(&Id::test_new(1)).unwrap().is_empty());
}
}
120 changes: 120 additions & 0 deletions compiler/noirc_evaluator/src/ssa/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,26 @@ mod test {
");
}

#[test]
fn basic_inlining_brillig_not_inlined_into_acir() {
// This test matches the `basic_inlining` test exactly except that f1 is marked as a Brillig runtime.
// We expect that Brillig entry points (e.g., Brillig functions called from ACIR) should never be inlined.
let src = "
acir(inline) fn foo f0 {
b0():
v1 = call f1() -> Field
return v1
}
brillig(inline) fn bar f1 {
b0():
return Field 72
}
";
let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.inline_functions(i64::MAX).unwrap();
assert_normalized_ssa_equals(ssa, src);
}

#[test]
fn complex_inlining() {
// This SSA is from issue #1327 which previously failed to inline properly
Expand Down Expand Up @@ -1030,4 +1050,104 @@ mod test {
}
");
}

#[test]
fn static_assertions_to_always_be_inlined() {
let src = "
brillig(inline) fn main f0 {
b0():
call f1(Field 1)
return
}
brillig(inline) fn foo f1 {
b0(v0: Field):
call assert_constant(v0)
return
}
";
let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.inline_functions(i64::MAX).unwrap();

assert_ssa_snapshot!(ssa, @r"
brillig(inline) fn main f0 {
b0():
return
}
");
}

#[test]
fn no_predicates_flag_inactive() {
let src = "
acir(inline) fn main f0 {
b0():
call f1()
return
}
acir(no_predicates) fn no_predicates f1 {
b0():
return
}
";

let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.inline_functions(i64::MAX).unwrap();
assert_normalized_ssa_equals(ssa, src);
}

#[test]
fn no_predicates_flag_active() {
let src = "
acir(inline) fn main f0 {
b0():
call f1()
return
}
acir(no_predicates) fn no_predicates f1 {
b0():
return
}
";

let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.inline_functions_with_no_predicates(i64::MAX).unwrap();

assert_ssa_snapshot!(ssa, @r"
acir(inline) fn main f0 {
b0():
return
}
");
}

#[test]
fn inline_always_function() {
let src = "
brillig(inline) fn main f0 {
b0():
call f1()
return
}

brillig(inline_always) fn always_inline f1 {
b0():
return
}
";
let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.inline_functions(i64::MIN).unwrap();
assert_ssa_snapshot!(ssa, @r"
brillig(inline) fn main f0 {
b0():
return
}
");

// Check that with a minimum inliner aggressiveness we do not inline a function
// not marked with `inline_always`
let no_inline_always_src = &src.replace("inline_always", "inline");
let ssa = Ssa::from_str(no_inline_always_src).unwrap();
let ssa = ssa.inline_functions(i64::MIN).unwrap();
assert_normalized_ssa_equals(ssa, no_inline_always_src);
}
}
Loading
Loading