Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions compiler/noirc_frontend/src/attribute_order.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use fm::FileId;
use noirc_errors::Span;
use petgraph::prelude::{DiGraph, NodeIndex};
use rustc_hash::FxHashMap as HashMap;

use crate::{
ast::Expression,
hir::{comptime::Value, def_map::LocalModuleId},
node_interner::FuncId,
};

#[derive(Debug)]
pub struct AttributeGraph {
default_stage: NodeIndex,

order: DiGraph<FuncId, f32>,

indices: HashMap<FuncId, NodeIndex>,

modified_functions: std::collections::HashSet<FuncId>,
}

#[derive(Debug, Copy, Clone)]
pub(crate) struct AttributeContext {
// The file where generated items should be added
pub(crate) file: FileId,
// The module where generated items should be added
pub(crate) module: LocalModuleId,
// The file where the attribute is located
pub(crate) attribute_file: FileId,
// The module where the attribute is located
pub(crate) attribute_module: LocalModuleId,
}

pub(crate) type CollectedAttributes = Vec<(FuncId, Value, Vec<Expression>, AttributeContext, Span)>;

impl AttributeContext {
pub(crate) fn new(file: FileId, module: LocalModuleId) -> Self {
Self { file, module, attribute_file: file, attribute_module: module }
}
}

impl Default for AttributeGraph {
fn default() -> Self {
let mut order = DiGraph::default();
let mut indices = HashMap::default();

let default_stage = order.add_node(FuncId::dummy_id());
indices.insert(FuncId::dummy_id(), default_stage);

Self { default_stage, order, indices, modified_functions: Default::default() }
}
}

impl AttributeGraph {
pub fn get_or_insert(&mut self, attr: FuncId) -> NodeIndex {
if let Some(index) = self.indices.get(&attr) {
return *index;
}

let index = self.order.add_node(attr);
self.indices.insert(attr, index);
index
}

pub fn add_ordering_constraint(&mut self, run_first: FuncId, run_second: FuncId) {
let first_index = self.get_or_insert(run_first);
let second_index = self.get_or_insert(run_second);

// Just for debugging
if run_first != FuncId::dummy_id() {
self.modified_functions.insert(run_first);
self.modified_functions.insert(run_second);
}

self.order.update_edge(second_index, first_index, 1.0);
}

/// The default ordering of an attribute: run in the default stage
pub fn run_in_default_stage(&mut self, attr: FuncId) {
let index = self.get_or_insert(attr);
self.order.update_edge(self.default_stage, index, 1.0);
}

pub(crate) fn sort_attributes_by_run_order(&self, attributes: &mut CollectedAttributes) {
let topological_sort = petgraph::algo::toposort(&self.order, None).unwrap();

let ordering: HashMap<FuncId, usize> =
topological_sort.into_iter().map(|index| (self.order[index], index.index())).collect();

attributes.sort_by_key(|(f, ..)| ordering[f]);
}
}
152 changes: 105 additions & 47 deletions compiler/noirc_frontend/src/elaborator/comptime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use noirc_errors::{Location, Span};

use crate::{
ast::{Documented, Expression, ExpressionKind},
attribute_order::{AttributeContext, CollectedAttributes},
hir::{
comptime::{Interpreter, InterpreterError, Value},
def_collector::{
Expand All @@ -15,7 +16,7 @@ use crate::{
},
dc_mod,
},
def_map::{LocalModuleId, ModuleId},
def_map::ModuleId,
resolution::errors::ResolverError,
},
hir_def::expr::{HirExpression, HirIdent},
Expand All @@ -28,24 +29,6 @@ use crate::{

use super::{Elaborator, FunctionContext, ResolverMeta};

#[derive(Debug, Copy, Clone)]
struct AttributeContext {
// The file where generated items should be added
file: FileId,
// The module where generated items should be added
module: LocalModuleId,
// The file where the attribute is located
attribute_file: FileId,
// The module where the attribute is located
attribute_module: LocalModuleId,
}

impl AttributeContext {
fn new(file: FileId, module: LocalModuleId) -> Self {
Self { file, module, attribute_file: file, attribute_module: module }
}
}

impl<'context> Elaborator<'context> {
/// Elaborate an expression from the middle of a comptime scope.
/// When this happens we require additional information to know
Expand Down Expand Up @@ -131,57 +114,58 @@ impl<'context> Elaborator<'context> {
}
}

fn run_comptime_attributes_on_item(
fn collect_comptime_attributes_on_item(
&mut self,
attributes: &[SecondaryAttribute],
item: Value,
span: Span,
attribute_context: AttributeContext,
generated_items: &mut CollectedItems,
attributes_to_run: &mut CollectedAttributes,
) {
for attribute in attributes {
self.run_comptime_attribute_on_item(
self.collect_comptime_attribute_on_item(
attribute,
&item,
span,
attribute_context,
generated_items,
attributes_to_run,
);
}
}

fn run_comptime_attribute_on_item(
fn collect_comptime_attribute_on_item(
&mut self,
attribute: &SecondaryAttribute,
item: &Value,
span: Span,
attribute_context: AttributeContext,
generated_items: &mut CollectedItems,
attributes_to_run: &mut CollectedAttributes,
) {
if let SecondaryAttribute::Meta(attribute) = attribute {
self.elaborate_in_comptime_context(|this| {
if let Err(error) = this.run_comptime_attribute_name_on_item(
if let Err(error) = this.collect_comptime_attribute_name_on_item(
&attribute.contents,
item.clone(),
span,
attribute.contents_span,
attribute_context,
generated_items,
attributes_to_run,
) {
this.errors.push(error);
}
});
}
}

fn run_comptime_attribute_name_on_item(
/// Resolve an attribute to the function it refers to and add it to `attributes_to_run`
fn collect_comptime_attribute_name_on_item(
&mut self,
attribute: &str,
item: Value,
span: Span,
attribute_span: Span,
attribute_context: AttributeContext,
generated_items: &mut CollectedItems,
attributes_to_run: &mut CollectedAttributes,
) -> Result<(), (CompilationError, FileId)> {
self.file = attribute_context.attribute_file;
self.local_module = attribute_context.attribute_module;
Expand Down Expand Up @@ -233,6 +217,19 @@ impl<'context> Elaborator<'context> {
return Err((ResolverError::NonFunctionInAnnotation { span }.into(), self.file));
};

attributes_to_run.push((function, item, arguments, attribute_context, span));
Ok(())
}

fn run_attribute(
&mut self,
attribute_context: AttributeContext,
function: FuncId,
arguments: Vec<Expression>,
item: Value,
location: Location,
generated_items: &mut CollectedItems,
) -> Result<(), (CompilationError, FileId)> {
self.file = attribute_context.file;
self.local_module = attribute_context.module;

Expand All @@ -244,10 +241,7 @@ impl<'context> Elaborator<'context> {
arguments,
location,
)
.map_err(|error| {
let file = error.get_location().file;
(error.into(), file)
})?;
.map_err(|error| error.into_compilation_error_pair())?;

arguments.insert(0, (item, location));

Expand Down Expand Up @@ -537,19 +531,19 @@ impl<'context> Elaborator<'context> {
functions: &[UnresolvedFunctions],
module_attributes: &[ModuleAttribute],
) -> CollectedItems {
let mut generated_items = CollectedItems::default();
let mut attributes_to_run = Vec::new();

for (trait_id, trait_) in traits {
let attributes = &trait_.trait_def.attributes;
let item = Value::TraitDefinition(*trait_id);
let span = trait_.trait_def.span;
let context = AttributeContext::new(trait_.file_id, trait_.module_id);
self.run_comptime_attributes_on_item(
self.collect_comptime_attributes_on_item(
attributes,
item,
span,
context,
&mut generated_items,
&mut attributes_to_run,
);
}

Expand All @@ -558,26 +552,38 @@ impl<'context> Elaborator<'context> {
let item = Value::StructDefinition(*struct_id);
let span = struct_def.struct_def.span;
let context = AttributeContext::new(struct_def.file_id, struct_def.module_id);
self.run_comptime_attributes_on_item(
self.collect_comptime_attributes_on_item(
attributes,
item,
span,
context,
&mut generated_items,
&mut attributes_to_run,
);
}

self.run_attributes_on_functions(functions, &mut generated_items);
self.collect_attributes_on_functions(functions, &mut attributes_to_run);
self.collect_attributes_on_modules(module_attributes, &mut attributes_to_run);

self.run_attributes_on_modules(module_attributes, &mut generated_items);
self.interner.attribute_order.sort_attributes_by_run_order(&mut attributes_to_run);

// run
let mut generated_items = CollectedItems::default();
for (attribute, item, args, context, span) in attributes_to_run {
let location = Location::new(span, context.attribute_file);
if let Err(error) =
self.run_attribute(context, attribute, args, item, location, &mut generated_items)
{
self.errors.push(error);
}
}

generated_items
}

fn run_attributes_on_modules(
fn collect_attributes_on_modules(
&mut self,
module_attributes: &[ModuleAttribute],
generated_items: &mut CollectedItems,
attributes_to_run: &mut CollectedAttributes,
) {
for module_attribute in module_attributes {
let local_id = module_attribute.module_id;
Expand All @@ -593,14 +599,20 @@ impl<'context> Elaborator<'context> {
attribute_module: module_attribute.attribute_module_id,
};

self.run_comptime_attribute_on_item(attribute, &item, span, context, generated_items);
self.collect_comptime_attribute_on_item(
attribute,
&item,
span,
context,
attributes_to_run,
);
}
}

fn run_attributes_on_functions(
fn collect_attributes_on_functions(
&mut self,
function_sets: &[UnresolvedFunctions],
generated_items: &mut CollectedItems,
attributes_to_run: &mut CollectedAttributes,
) {
for function_set in function_sets {
self.self_type = function_set.self_type.clone();
Expand All @@ -610,12 +622,12 @@ impl<'context> Elaborator<'context> {
let attributes = function.secondary_attributes();
let item = Value::FunctionDefinition(*function_id);
let span = function.span();
self.run_comptime_attributes_on_item(
self.collect_comptime_attributes_on_item(
attributes,
item,
span,
context,
generated_items,
attributes_to_run,
);
}
}
Expand Down Expand Up @@ -650,4 +662,50 @@ impl<'context> Elaborator<'context> {
_ => false,
}
}

pub(super) fn register_attribute_order(
&mut self,
id: FuncId,
attributes: &[SecondaryAttribute],
) {
let mut has_order = false;

for attribute in attributes {
let (name, run_before) = match attribute {
SecondaryAttribute::RunBefore(name) => (name, true),
SecondaryAttribute::RunAfter(name) => (name, false),
_ => continue,
};

// Parse a path from #[run_before(path)]
let Some(path) = Parser::for_str(name).parse_path_no_turbofish() else {
todo!("function should be a path")
};

let definition_id = self.resolve_variable(path).id;

match self.interner.definition(definition_id).kind {
DefinitionKind::Function(attribute_arg) => {
if attribute_arg == id {
todo!("Attribute cannot be run before or after itself");
}

has_order = true;
if run_before {
self.interner.attribute_order.add_ordering_constraint(attribute_arg, id);
} else {
self.interner.attribute_order.add_ordering_constraint(id, attribute_arg);
}
}
_ => {
todo!("path doesn't refer to a function")
}
}
}

// If no ordering was specified, set it to the default stage.
if !has_order {
self.interner.attribute_order.run_in_default_stage(id);
}
}
}
Loading