Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions compiler/noirc_evaluator/src/ssa/opt/inline_simple_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use crate::ssa::{
call_graph::CallGraph,
function::{Function, RuntimeType},
},
opt::brillig_entry_points::get_brillig_entry_points,
ssa_gen::Ssa,
};

Expand All @@ -33,13 +34,25 @@ impl Ssa {
pub(crate) fn inline_simple_functions(mut self: Ssa) -> Result<Ssa, RuntimeError> {
let call_graph = CallGraph::from_ssa(&self);
let recursive_functions = call_graph.get_recursive_functions();
let brillig_entry_points =
get_brillig_entry_points(&self.functions, self.main_id, &call_graph);

let should_inline_call = |callee: &Function| {
let runtime = callee.runtime();
if let RuntimeType::Acir(_) = callee.runtime() {
// Functions marked to not have predicates should be preserved.
if callee.is_no_predicates() {
return false;
}
// ACIR entry points (e.g., foldable functions) should be preserved.
if runtime.is_entry_point() {
return false;
}
}

// Do not inline Brillig entry points
if brillig_entry_points.contains_key(&callee.id()) {
return false;
}

let entry_block_id = callee.entry_block();
Expand Down Expand Up @@ -374,4 +387,46 @@ mod test {
";
assert_does_not_inline(src);
}

#[test]
fn basic_inlining_brillig_not_inlined_into_acir() {
// We expect that Brillig entry points (e.g., Brillig functions called from ACIR) should never be inlined.
let src = "
acir(inline) fn foo f0 {
b0():
v1 = call f1() -> Field
return v1
}
brillig(inline) fn bar f1 {
b0():
return Field 72
}
";
assert_does_not_inline(src);
}

#[test]
fn does_not_inline_acir_fold_functions() {
let src = "
acir(inline) fn main f0 {
b0(v0: Field, v1: Field):
v3 = call f1(v0, v1) -> Field
v4 = call f1(v0, v1) -> Field
v5 = call f1(v0, v1) -> Field
v6 = eq v3, v4
constrain v3 == v4
v7 = eq v4, v5
constrain v4 == v5
return
}
acir(fold) fn foo f1 {
b0(v0: Field, v1: Field):
v2 = eq v0, v1
v3 = not v2
constrain v2 == u1 0
return v0
}
";
assert_does_not_inline(src);
}
}
214 changes: 132 additions & 82 deletions compiler/noirc_evaluator/src/ssa/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::ssa::{
basic_block::BasicBlockId,
call_graph::CallGraph,
dfg::InsertInstructionResult,
function::{Function, FunctionId, RuntimeType},
function::{Function, FunctionId},
instruction::{Instruction, InstructionId, TerminatorInstruction},
value::{Value, ValueId},
},
Expand Down Expand Up @@ -79,8 +79,7 @@ impl Ssa {
inline_no_predicates_functions,
aggressiveness,
);
self =
Self::inline_functions_inner(self, &inline_infos, inline_no_predicates_functions)?;
self = Self::inline_functions_inner(self, &inline_infos)?;

let num_functions_after = self.functions.len();
if num_functions_after == num_functions_before {
Expand All @@ -91,30 +90,15 @@ impl Ssa {
Ok(self)
}

fn inline_functions_inner(
mut self,
inline_infos: &InlineInfos,
inline_no_predicates_functions: bool,
) -> Result<Ssa, RuntimeError> {
fn inline_functions_inner(mut self, inline_infos: &InlineInfos) -> Result<Ssa, RuntimeError> {
let inline_targets = inline_infos.iter().filter_map(|(id, info)| {
let dfg = &self.functions[id].dfg;
info.is_inline_target(dfg).then_some(*id)
});

let should_inline_call = |callee: &Function| -> bool {
match callee.runtime() {
RuntimeType::Acir(_) => {
// If we have not already finished the flattening pass, functions marked
// to not have predicates should be preserved.
let preserve_function =
!inline_no_predicates_functions && callee.is_no_predicates();
!preserve_function
}
RuntimeType::Brillig(_) => {
// We inline inline if the function called wasn't ruled out as too costly or recursive.
InlineInfo::should_inline(inline_infos, callee.id())
}
}
// We defer to the inline info computation to determine whether a function should be inlined
InlineInfo::should_inline(inline_infos, callee.id())
};

// NOTE: Functions are processed independently of each other, with the final mapping replacing the original,
Expand Down Expand Up @@ -464,29 +448,30 @@ impl<'function> PerFunctionContext<'function> {
Instruction::Call { func, arguments } => match self.get_function(*func) {
Some(func_id) => {
let call_stack = self.source_function.dfg.get_instruction_call_stack(*id);
if let Some(callee) = self.should_inline_call(ssa, func_id, call_stack)? {
if should_inline_call(callee) {
self.inline_function(
ssa,
*id,
func_id,
arguments,
should_inline_call,
)?;

// This is only relevant during handling functions with `InlineType::NoPredicates` as these
// can pollute the function they're being inlined into with `Instruction::EnabledSideEffects`,
// resulting in predicates not being applied properly.
//
// Note that this doesn't cover the case in which there exists an `Instruction::EnableSideEffectsIf`
// within the function being inlined whilst the source function has not encountered one yet.
// In practice this isn't an issue as the last `Instruction::EnableSideEffectsIf` in the
// function being inlined will be to turn off predicates rather than to create one.
if let Some(condition) = side_effects_enabled {
self.context.builder.insert_enable_side_effects_if(condition);
}
} else {
self.push_instruction(*id);
let callee = &ssa.functions[&func_id];

// Sanity check to validate runtime compatibility
self.validate_callee(callee, call_stack)?;

// Do not inline self-recursive functions on the top level.
// Inlining a self-recursive function works when there is something to inline into
// by importing all the recursive blocks, but for the entry function there is no wrapper.
// We must do this check here as inlining can be can triggered on a non-inline target (e.g., non-entry point).
let inlining_self_recursion_at_top_level =
self.entry_function.id() == func_id;
if !inlining_self_recursion_at_top_level && should_inline_call(callee) {
self.inline_function(ssa, *id, func_id, arguments, should_inline_call)?;

// This is only relevant during handling functions with `InlineType::NoPredicates` as these
// can pollute the function they're being inlined into with `Instruction::EnabledSideEffects`,
// resulting in predicates not being applied properly.
//
// Note that this doesn't cover the case in which there exists an `Instruction::EnableSideEffectsIf`
// within the function being inlined whilst the source function has not encountered one yet.
// In practice this isn't an issue as the last `Instruction::EnableSideEffectsIf` in the
// function being inlined will be to turn off predicates rather than to create one.
if let Some(condition) = side_effects_enabled {
self.context.builder.insert_enable_side_effects_if(condition);
}
} else {
self.push_instruction(*id);
Expand All @@ -504,46 +489,33 @@ impl<'function> PerFunctionContext<'function> {
Ok(())
}

fn should_inline_call<'a>(
/// Extra error check where given a caller's runtime its callee runtime is valid.
/// We determine validity as the following (where we have caller -> callee).
/// Valid:
/// - ACIR -> ACIR
/// - ACIR -> Brillig
/// - Brillig -> Brillig
///
/// Invalid:
/// - Brillig -> ACIR
///
/// Whether a valid callee should be inlined is determined separately by the inline info computation.
fn validate_callee(
&self,
ssa: &'a Ssa,
called_func_id: FunctionId,
callee: &Function,
call_stack: Vec<Location>,
) -> Result<Option<&'a Function>, RuntimeError> {
// Do not inline self-recursive functions on the top level.
// Inlining a self-recursive function works when there is something to inline into
// by importing all the recursive blocks, but for the entry function there is no wrapper.
if self.entry_function.id() == called_func_id {
return Ok(None);
}

let callee = &ssa.functions[&called_func_id];

match callee.runtime() {
RuntimeType::Acir(inline_type) => {
// If the called function is acir, we inline if it's not an entry point
// If it is called from brillig, it cannot be inlined because runtimes do not share the same semantic
if self.entry_function.runtime().is_brillig() {
return Err(RuntimeError::UnconstrainedCallingConstrained {
call_stack,
constrained: callee.name().to_string(),
unconstrained: self.entry_function.name().to_string(),
});
}
if inline_type.is_entry_point() {
assert!(!self.entry_function.runtime().is_brillig());
return Ok(None);
}
}
RuntimeType::Brillig(_) => {
if self.entry_function.runtime().is_acir() {
// We never inline a brillig function into an ACIR function.
return Ok(None);
}
}
) -> Result<(), RuntimeError> {
if self.entry_function.runtime().is_brillig() && callee.runtime().is_acir() {
// If the caller is Brillig and the called function is ACIR,
// it cannot be inlined because runtimes do not share the same semantics
return Err(RuntimeError::UnconstrainedCallingConstrained {
call_stack,
constrained: callee.name().to_string(),
unconstrained: self.entry_function.name().to_string(),
});
}

Ok(Some(callee))
Ok(())
}

/// Inline a function call and remember the inlined return values in the values map
Expand Down Expand Up @@ -712,7 +684,11 @@ mod test {
use crate::{
assert_ssa_snapshot,
errors::RuntimeError,
ssa::{Ssa, ir::instruction::TerminatorInstruction, opt::assert_normalized_ssa_equals},
ssa::{
Ssa,
ir::{instruction::TerminatorInstruction, map::Id},
opt::assert_normalized_ssa_equals,
},
};

#[test]
Expand Down Expand Up @@ -807,7 +783,6 @@ mod test {
v2 = call f1(u32 5) -> u32
return v2
}

acir(inline) fn factorial f1 {
b0(v1: u32):
v2 = lt v1, u32 1
Expand Down Expand Up @@ -862,6 +837,55 @@ mod test {
");
}

/// This test is the same as [recursive_functions] we just want to test that inlining
/// does not fail when triggered from the self recursive non-entry point function instead
/// of the program entry point.
#[test]
fn recursive_functions_non_inline_target() {
let src = "
acir(inline) fn main f0 {
b0():
v2 = call f1(u32 5) -> u32
return v2
}
acir(inline) fn factorial f1 {
b0(v1: u32):
v2 = lt v1, u32 1
jmpif v2 then: b1, else: b2
b1():
jmp b3(u32 1)
b2():
v4 = sub v1, u32 1
v5 = call f1(v4) -> u32
v6 = mul v1, v5
jmp b3(v6)
b3(v7: u32):
return v7
}
";
let ssa = Ssa::from_str(src).unwrap();
let f1 = &ssa.functions[&Id::test_new(1)];
let function = f1.inlined(&ssa, &|_| true).unwrap();
// The expected string must be formatted this way as to account for newlines and whitespace
assert_eq!(
function.to_string(),
"acir(inline) fn factorial f1 {
b0(v0: u32):
v3 = eq v0, u32 0
jmpif v3 then: b1, else: b2
b1():
jmp b3(u32 1)
b2():
v5 = sub v0, u32 1
v7 = call f1(v5) -> u32
v8 = mul v0, v7
jmp b3(v8)
b3(v4: u32):
return v4
}"
);
}

#[test]
fn displaced_return_mapping() {
// This test is designed specifically to catch a regression in which the ids of blocks
Expand Down Expand Up @@ -1334,4 +1358,30 @@ mod test {
panic!("Expected inlining to fail with RuntimeError::UnconstrainedCallingConstrained");
}
}

#[test]
fn does_not_inline_acir_fold_functions() {
let src = "
acir(inline) fn main f0 {
b0(v0: Field, v1: Field):
v3 = call f1(v0, v1) -> Field
v4 = call f1(v0, v1) -> Field
v5 = call f1(v0, v1) -> Field
v6 = eq v3, v4
constrain v3 == v4
v7 = eq v4, v5
constrain v4 == v5
return
}
acir(fold) fn foo f1 {
b0(v0: Field, v1: Field):
v2 = eq v0, v1
v3 = not v2
constrain v2 == u1 0
return v0
}
";
let ssa = Ssa::from_str(src).unwrap();
assert_normalized_ssa_equals(ssa, src);
}
}
Loading
Loading