diff --git a/compiler/noirc_driver/src/abi_gen.rs b/compiler/noirc_driver/src/abi_gen.rs index 59b3faf1a4e..9838a8af210 100644 --- a/compiler/noirc_driver/src/abi_gen.rs +++ b/compiler/noirc_driver/src/abi_gen.rs @@ -112,7 +112,7 @@ pub(super) fn abi_type_from_hir_type(context: &Context, typ: &Type) -> AbiType { Type::DataType(def, args) => { let struct_type = def.borrow(); - let fields = struct_type.get_fields(args); + let fields = struct_type.get_fields(args).unwrap_or_default(); let fields = vecmap(fields, |(name, typ)| (name, abi_type_from_hir_type(context, &typ))); // For the ABI, we always want to resolve the struct paths from the root crate diff --git a/compiler/noirc_driver/src/lib.rs b/compiler/noirc_driver/src/lib.rs index be5cde1e0ea..9b0172853c0 100644 --- a/compiler/noirc_driver/src/lib.rs +++ b/compiler/noirc_driver/src/lib.rs @@ -551,9 +551,10 @@ fn compile_contract_inner( .map(|struct_id| { let typ = context.def_interner.get_type(struct_id); let typ = typ.borrow(); - let fields = vecmap(typ.get_fields(&[]), |(name, typ)| { - (name, abi_type_from_hir_type(context, &typ)) - }); + let fields = + vecmap(typ.get_fields(&[]).unwrap_or_default(), |(name, typ)| { + (name, abi_type_from_hir_type(context, &typ)) + }); let path = context.fully_qualified_struct_path(context.root_crate_id(), typ.id); AbiType::Struct { path, fields } diff --git a/compiler/noirc_frontend/src/elaborator/enums.rs b/compiler/noirc_frontend/src/elaborator/enums.rs index f574aa81619..2ccd2b25561 100644 --- a/compiler/noirc_frontend/src/elaborator/enums.rs +++ b/compiler/noirc_frontend/src/elaborator/enums.rs @@ -70,6 +70,7 @@ impl Elaborator<'_> { type_id: Some(type_id), trait_id: None, trait_impl: None, + enum_variant_index: Some(variant_index), is_entry_point: false, has_inline_attribute: false, function_body: FunctionBody::Resolved, diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index 650bd579f2f..68e13688b1c 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -622,7 +622,9 @@ impl<'context> Elaborator<'context> { (typ, generics) } else { match self.lookup_type_or_error(path) { - Some(Type::DataType(r#type, struct_generics)) => (r#type, struct_generics), + Some(Type::DataType(r#type, struct_generics)) if r#type.borrow().is_struct() => { + (r#type, struct_generics) + } Some(typ) => { self.push_err(ResolverError::NonStructUsedInConstructor { typ: typ.to_string(), @@ -649,7 +651,11 @@ impl<'context> Elaborator<'context> { let generics = struct_generics.clone(); let fields = constructor.fields; - let field_types = r#type.borrow().get_fields_with_visibility(&struct_generics); + let field_types = r#type + .borrow() + .get_fields_with_visibility(&struct_generics) + .expect("This type should already be validated to be a struct"); + let fields = self.resolve_constructor_expr_fields(struct_type.clone(), field_types, fields, span); let expr = HirExpression::Constructor(HirConstructorExpression { @@ -660,7 +666,7 @@ impl<'context> Elaborator<'context> { let struct_id = struct_type.borrow().id; let reference_location = Location::new(last_segment.ident.span(), self.file); - self.interner.add_struct_reference(struct_id, reference_location, is_self_type); + self.interner.add_type_reference(struct_id, reference_location, is_self_type); (expr, Type::DataType(struct_type, generics)) } @@ -683,7 +689,10 @@ impl<'context> Elaborator<'context> { ) -> Vec<(Ident, ExprId)> { let mut ret = Vec::with_capacity(fields.len()); let mut seen_fields = HashSet::default(); - let mut unseen_fields = struct_type.borrow().field_names(); + let mut unseen_fields = struct_type + .borrow() + .field_names() + .expect("This type should already be validated to be a struct"); for (field_name, field) in fields { let expected_field_with_index = field_types diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 24b0bff457a..849552faf79 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -12,13 +12,11 @@ use crate::{ graph::CrateId, hir::{ def_collector::dc_crate::{ - filter_literal_globals, CompilationError, ImplMap, UnresolvedFunctions, - UnresolvedGlobal, UnresolvedStruct, UnresolvedTraitImpl, UnresolvedTypeAlias, - }, - def_collector::{ - dc_crate::{CollectedItems, UnresolvedEnum}, - errors::DefCollectorErrorKind, + filter_literal_globals, CollectedItems, CompilationError, ImplMap, UnresolvedEnum, + UnresolvedFunctions, UnresolvedGlobal, UnresolvedStruct, UnresolvedTraitImpl, + UnresolvedTypeAlias, }, + def_collector::errors::DefCollectorErrorKind, def_map::{DefMaps, ModuleData}, def_map::{LocalModuleId, ModuleId, MAIN_FUNCTION}, resolution::errors::ResolverError, @@ -26,11 +24,10 @@ use crate::{ type_check::{generics::TraitGenerics, TypeCheckError}, Context, }, - hir_def::traits::TraitImpl, hir_def::{ expr::{HirCapturedVar, HirIdent}, function::{FuncMeta, FunctionBody, HirFunction}, - traits::TraitConstraint, + traits::{TraitConstraint, TraitImpl}, types::{Generics, Kind, ResolvedGeneric}, }, node_interner::{ @@ -1006,6 +1003,7 @@ impl<'context> Elaborator<'context> { type_id: struct_id, trait_id, trait_impl: self.current_trait_impl, + enum_variant_index: None, parameters: parameters.into(), parameter_idents, return_type: func.def.return_type.clone(), @@ -1035,13 +1033,21 @@ impl<'context> Elaborator<'context> { self.mark_type_as_used(typ); } } - Type::DataType(struct_type, generics) => { - self.mark_struct_as_constructed(struct_type.clone()); + Type::DataType(datatype, generics) => { + self.mark_struct_as_constructed(datatype.clone()); for generic in generics { self.mark_type_as_used(generic); } - for (_, typ) in struct_type.borrow().get_fields(generics) { - self.mark_type_as_used(&typ); + if let Some(fields) = datatype.borrow().get_fields(generics) { + for (_, typ) in fields { + self.mark_type_as_used(&typ); + } + } else if let Some(variants) = datatype.borrow().get_variants(generics) { + for (_, variant_types) in variants { + for typ in variant_types { + self.mark_type_as_used(&typ); + } + } } } Type::Alias(alias_type, generics) => { @@ -1776,7 +1782,7 @@ impl<'context> Elaborator<'context> { // Only handle structs without generics as any generics args will be checked // after monomorphization when performing SSA codegen if struct_type.borrow().generics.is_empty() { - let fields = struct_type.borrow().get_fields(&[]); + let fields = struct_type.borrow().get_fields(&[]).unwrap(); for (_, field_type) in fields.iter() { if field_type.is_nested_slice() { let location = struct_type.borrow().location; @@ -1831,6 +1837,9 @@ impl<'context> Elaborator<'context> { span: typ.enum_def.span, }; + datatype.borrow_mut().init_variants(); + let module_id = ModuleId { krate: self.crate_id, local_id: typ.module_id }; + for (i, variant) in typ.enum_def.variants.iter().enumerate() { let types = vecmap(&variant.item.parameters, |typ| self.resolve_type(typ.clone())); let name = variant.item.name.clone(); @@ -1848,7 +1857,8 @@ impl<'context> Elaborator<'context> { unresolved.clone(), ); - self.interner.add_definition_location(ReferenceId::EnumVariant(*type_id, i), None); + let reference_id = ReferenceId::EnumVariant(*type_id, i); + self.interner.add_definition_location(reference_id, Some(module_id)); } } } diff --git a/compiler/noirc_frontend/src/elaborator/patterns.rs b/compiler/noirc_frontend/src/elaborator/patterns.rs index 6ab12d1e537..eab0b91f0f6 100644 --- a/compiler/noirc_frontend/src/elaborator/patterns.rs +++ b/compiler/noirc_frontend/src/elaborator/patterns.rs @@ -192,7 +192,11 @@ impl<'context> Elaborator<'context> { }; let (struct_type, generics) = match self.lookup_type_or_error(name) { - Some(Type::DataType(struct_type, struct_generics)) => (struct_type, struct_generics), + Some(Type::DataType(struct_type, struct_generics)) + if struct_type.borrow().is_struct() => + { + (struct_type, struct_generics) + } None => return error_identifier(self), Some(typ) => { let typ = typ.to_string(); @@ -234,7 +238,7 @@ impl<'context> Elaborator<'context> { let struct_id = struct_type.borrow().id; let reference_location = Location::new(name_span, self.file); - self.interner.add_struct_reference(struct_id, reference_location, is_self_type); + self.interner.add_type_reference(struct_id, reference_location, is_self_type); for (field_index, field) in fields.iter().enumerate() { let reference_location = Location::new(field.0.span(), self.file); @@ -260,7 +264,10 @@ impl<'context> Elaborator<'context> { ) -> Vec<(Ident, HirPattern)> { let mut ret = Vec::with_capacity(fields.len()); let mut seen_fields = HashSet::default(); - let mut unseen_fields = struct_type.borrow().field_names(); + let mut unseen_fields = struct_type + .borrow() + .field_names() + .expect("This type should already be validated to be a struct"); for (field, pattern) in fields { let (field_type, visibility) = expected_type diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 4d8a90d25d4..53d0860ebf1 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -158,7 +158,7 @@ impl<'context> Elaborator<'context> { // Record the location of the type reference self.interner.push_type_ref_location(resolved_type.clone(), location); if !is_synthetic { - self.interner.add_struct_reference( + self.interner.add_type_reference( data_type.borrow().id, location, is_self_type_name, @@ -1477,7 +1477,7 @@ impl<'context> Elaborator<'context> { let datatype = datatype.borrow(); let mut has_field_with_function_type = false; - if let Some(fields) = datatype.try_fields_raw() { + if let Some(fields) = datatype.fields_raw() { has_field_with_function_type = fields .iter() .any(|field| field.name.0.contents == method_name && field.typ.is_function()); diff --git a/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs b/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs index c198c5b009c..d46484d05fa 100644 --- a/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs +++ b/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs @@ -5,8 +5,8 @@ use crate::ast::{ ArrayLiteral, AssignStatement, BlockExpression, CallExpression, CastExpression, ConstrainKind, ConstructorExpression, ExpressionKind, ForLoopStatement, ForRange, GenericTypeArgs, Ident, IfExpression, IndexExpression, InfixExpression, LValue, Lambda, Literal, - MemberAccessExpression, MethodCallExpression, Path, PathSegment, Pattern, PrefixExpression, - UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, + MemberAccessExpression, MethodCallExpression, Path, PathKind, PathSegment, Pattern, + PrefixExpression, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, }; use crate::ast::{ConstrainStatement, Expression, Statement, StatementKind}; use crate::hir_def::expr::{ @@ -217,7 +217,9 @@ impl HirExpression { HirExpression::EnumConstructor(constructor) => { let typ = constructor.r#type.borrow(); let variant = &typ.variant_at(constructor.variant_index); - let path = Path::from_single(variant.name.to_string(), span); + let segment1 = PathSegment { ident: typ.name.clone(), span, generics: None }; + let segment2 = PathSegment { ident: variant.name.clone(), span, generics: None }; + let path = Path { segments: vec![segment1, segment2], kind: PathKind::Plain, span }; let func = Box::new(Expression::new(ExpressionKind::Variable(path), span)); let arguments = vecmap(&constructor.arguments, |arg| arg.to_display_ast(interner)); let call = CallExpression { func, arguments, is_macro_call: false }; diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index f98f9c5954c..6503b0cf77b 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -569,9 +569,11 @@ fn struct_def_fields( let mut fields = im::Vector::new(); - for (field_name, field_type) in struct_def.get_fields(&generic_args) { - let name = Value::Quoted(Rc::new(vec![Token::Ident(field_name)])); - fields.push_back(Value::Tuple(vec![name, Value::Type(field_type)])); + if let Some(struct_fields) = struct_def.get_fields(&generic_args) { + for (field_name, field_type) in struct_fields { + let name = Value::Quoted(Rc::new(vec![Token::Ident(field_name)])); + fields.push_back(Value::Tuple(vec![name, Value::Type(field_type)])); + } } let typ = Type::Slice(Box::new(Type::Tuple(vec![ @@ -597,10 +599,12 @@ fn struct_def_fields_as_written( let mut fields = im::Vector::new(); - for field in struct_def.get_fields_as_written() { - let name = Value::Quoted(Rc::new(vec![Token::Ident(field.name.to_string())])); - let typ = Value::Type(field.typ); - fields.push_back(Value::Tuple(vec![name, typ])); + if let Some(struct_fields) = struct_def.get_fields_as_written() { + for field in struct_fields { + let name = Value::Quoted(Rc::new(vec![Token::Ident(field.name.to_string())])); + let typ = Value::Type(field.typ); + fields.push_back(Value::Tuple(vec![name, typ])); + } } let typ = Type::Slice(Box::new(Type::Tuple(vec![ @@ -1456,7 +1460,8 @@ fn zeroed(return_type: Type, span: Span) -> IResult { Type::Unit => Ok(Value::Unit), Type::Tuple(fields) => Ok(Value::Tuple(try_vecmap(fields, |field| zeroed(field, span))?)), Type::DataType(struct_type, generics) => { - let fields = struct_type.borrow().get_fields(&generics); + // TODO: Handle enums + let fields = struct_type.borrow().get_fields(&generics).unwrap(); let mut values = HashMap::default(); for (field_name, field_type) in fields { diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index 440426c5e75..c77b3f07a7f 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -1072,7 +1072,7 @@ pub fn collect_struct( } }; - interner.set_doc_comments(ReferenceId::Struct(id), doc_comments); + interner.set_doc_comments(ReferenceId::Type(id), doc_comments); for (index, field) in unresolved.struct_def.fields.iter().enumerate() { if !field.doc_comments.is_empty() { @@ -1106,7 +1106,7 @@ pub fn collect_struct( } if interner.is_in_lsp_mode() { - interner.register_struct(id, name.to_string(), visibility, parent_module_id); + interner.register_type(id, name.to_string(), visibility, parent_module_id); } Some((id, unresolved)) @@ -1167,7 +1167,7 @@ pub fn collect_enum( } }; - interner.set_doc_comments(ReferenceId::Enum(id), doc_comments); + interner.set_doc_comments(ReferenceId::Type(id), doc_comments); for (index, variant) in unresolved.enum_def.variants.iter().enumerate() { if !variant.doc_comments.is_empty() { @@ -1201,7 +1201,7 @@ pub fn collect_enum( } if interner.is_in_lsp_mode() { - interner.register_enum(id, name.to_string(), visibility, parent_module_id); + interner.register_type(id, name.to_string(), visibility, parent_module_id); } Some((id, unresolved)) diff --git a/compiler/noirc_frontend/src/hir_def/function.rs b/compiler/noirc_frontend/src/hir_def/function.rs index 0e51539de91..75bb4f50541 100644 --- a/compiler/noirc_frontend/src/hir_def/function.rs +++ b/compiler/noirc_frontend/src/hir_def/function.rs @@ -141,6 +141,9 @@ pub struct FuncMeta { /// The trait impl this function belongs to, if any pub trait_impl: Option, + /// If this function is the one related to an enum variant, this holds its index (relative to `type_id`) + pub enum_variant_index: Option, + /// True if this function is an entry point to the program. /// For non-contracts, this means the function is `main`. pub is_entry_point: bool, diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index bab8f566f7c..1a9241b3b46 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -448,13 +448,19 @@ impl DataType { self.body = TypeBody::Struct(fields); } - pub(crate) fn push_variant(&mut self, variant: EnumVariant) { + pub(crate) fn init_variants(&mut self) { match &mut self.body { TypeBody::None => { - self.body = TypeBody::Enum(vec![variant]); + self.body = TypeBody::Enum(vec![]); } + _ => panic!("Called init_variants but body was None"), + } + } + + pub(crate) fn push_variant(&mut self, variant: EnumVariant) { + match &mut self.body { TypeBody::Enum(variants) => variants.push(variant), - TypeBody::Struct(_) => panic!("Called push_variant on a non-variant type {self}"), + _ => panic!("Called push_variant on {self} but body wasn't an enum"), } } @@ -464,32 +470,19 @@ impl DataType { /// Retrieve the fields of this type with no modifications. /// Returns None if this is not a struct type. - pub fn try_fields_raw(&self) -> Option<&[StructField]> { + pub fn fields_raw(&self) -> Option<&[StructField]> { match &self.body { TypeBody::Struct(fields) => Some(fields), _ => None, } } - /// Retrieve the fields of this type with no modifications. - /// Panics if this is not a struct type. - fn fields_raw(&self) -> &[StructField] { - match &self.body { - TypeBody::Struct(fields) => fields, - // Turns out we call `fields_raw` in a few places before a type may be fully finished. - // One of these is when checking for nested slices, so that check will have false - // negatives. - TypeBody::None => &[], - _ => panic!("Called DataType::fields_raw on a non-struct type: {}", self.name), - } - } - /// Retrieve the variants of this type with no modifications. /// Panics if this is not an enum type. - fn variants_raw(&self) -> &[EnumVariant] { + fn variants_raw(&self) -> Option<&[EnumVariant]> { match &self.body { - TypeBody::Enum(variants) => variants, - _ => panic!("Called DataType::variants_raw on a non-enum type: {}", self.name), + TypeBody::Enum(variants) => Some(variants), + _ => None, } } @@ -501,7 +494,7 @@ impl DataType { } /// Returns the field matching the given field name, as well as its visibility and field index. - /// Panics if this is not a struct type. + /// Always returns None if this is not a struct type. pub fn get_field( &self, field_name: &str, @@ -509,7 +502,7 @@ impl DataType { ) -> Option<(Type, ItemVisibility, usize)> { assert_eq!(self.generics.len(), generic_args.len()); - let mut fields = self.fields_raw().iter().enumerate(); + let mut fields = self.fields_raw()?.iter().enumerate(); fields.find(|(_, field)| field.name.0.contents == field_name).map(|(i, field)| { let generics = self.generics.iter().zip(generic_args); let substitutions = generics @@ -523,38 +516,38 @@ impl DataType { } /// Returns all the fields of this type, after being applied to the given generic arguments. - /// Panics if this is not a struct type. + /// Returns None if this is not a struct type. pub fn get_fields_with_visibility( &self, generic_args: &[Type], - ) -> Vec<(String, ItemVisibility, Type)> { + ) -> Option> { let substitutions = self.get_fields_substitutions(generic_args); - vecmap(self.fields_raw(), |field| { + Some(vecmap(self.fields_raw()?, |field| { let name = field.name.0.contents.clone(); (name, field.visibility, field.typ.substitute(&substitutions)) - }) + })) } - /// Retrieve the fields of this type. Panics if this is not a struct type - pub fn get_fields(&self, generic_args: &[Type]) -> Vec<(String, Type)> { + /// Retrieve the fields of this type. Returns None if this is not a field type + pub fn get_fields(&self, generic_args: &[Type]) -> Option> { let substitutions = self.get_fields_substitutions(generic_args); - vecmap(self.fields_raw(), |field| { + Some(vecmap(self.fields_raw()?, |field| { let name = field.name.0.contents.clone(); (name, field.typ.substitute(&substitutions)) - }) + })) } - /// Retrieve the variants of this type. Panics if this is not an enum type - pub fn get_variants(&self, generic_args: &[Type]) -> Vec<(String, Vec)> { + /// Retrieve the variants of this type. Returns None if this is not an enum type + pub fn get_variants(&self, generic_args: &[Type]) -> Option)>> { let substitutions = self.get_fields_substitutions(generic_args); - vecmap(self.variants_raw(), |variant| { + Some(vecmap(self.variants_raw()?, |variant| { let name = variant.name.to_string(); let args = vecmap(&variant.params, |param| param.substitute(&substitutions)); (name, args) - }) + })) } fn get_fields_substitutions( @@ -579,34 +572,35 @@ impl DataType { /// This method is almost never what is wanted for type checking or monomorphization, /// prefer to use `get_fields` whenever possible. /// - /// Panics if this is not a struct type. - pub fn get_fields_as_written(&self) -> Vec { - self.fields_raw().to_vec() + /// Returns None if this is not a struct type. + pub fn get_fields_as_written(&self) -> Option> { + Some(self.fields_raw()?.to_vec()) } /// Returns the name and raw parameters of each variant of this type. /// This will not substitute any generic arguments so a generic variant like `X` /// in `enum Foo { X(T) }` will return a `("X", Vec)` pair. /// - /// Panics if this is not an enum type. - pub fn get_variants_as_written(&self) -> Vec { - self.variants_raw().to_vec() + /// Returns None if this is not an enum type. + pub fn get_variants_as_written(&self) -> Option> { + Some(self.variants_raw()?.to_vec()) } /// Returns the field at the given index. Panics if no field exists at the given index or this /// is not a struct type. pub fn field_at(&self, index: usize) -> &StructField { - &self.fields_raw()[index] + &self.fields_raw().unwrap()[index] } /// Returns the enum variant at the given index. Panics if no field exists at the given index /// or this is not an enum type. pub fn variant_at(&self, index: usize) -> &EnumVariant { - &self.variants_raw()[index] + &self.variants_raw().unwrap()[index] } - pub fn field_names(&self) -> BTreeSet { - self.fields_raw().iter().map(|field| field.name.clone()).collect() + /// Returns each of this type's field names. Returns None if this is not a struct type. + pub fn field_names(&self) -> Option> { + Some(self.fields_raw()?.iter().map(|field| field.name.clone()).collect()) } /// Instantiate this struct type, returning a Vec of the new generic args (in @@ -1258,11 +1252,14 @@ impl Type { } Type::String(length) => length.is_valid_for_program_input(), Type::Tuple(elements) => elements.iter().all(|elem| elem.is_valid_for_program_input()), - Type::DataType(definition, generics) => definition - .borrow() - .get_fields(generics) - .into_iter() - .all(|(_, field)| field.is_valid_for_program_input()), + Type::DataType(definition, generics) => { + if let Some(fields) = definition.borrow().get_fields(generics) { + fields.into_iter().all(|(_, field)| field.is_valid_for_program_input()) + } else { + // Arbitrarily disallow enums from program input, though we may support them later + false + } + } Type::InfixExpr(lhs, _, rhs, _) => { lhs.is_valid_for_program_input() && rhs.is_valid_for_program_input() @@ -1313,11 +1310,14 @@ impl Type { } Type::String(length) => length.is_valid_non_inlined_function_input(), Type::Tuple(elements) => elements.iter().all(|elem| elem.is_valid_non_inlined_function_input()), - Type::DataType(definition, generics) => definition - .borrow() - .get_fields(generics) - .into_iter() - .all(|(_, field)| field.is_valid_non_inlined_function_input()), + Type::DataType(definition, generics) => { + if let Some(fields) = definition.borrow().get_fields(generics) { + fields.into_iter() + .all(|(_, field)| field.is_valid_non_inlined_function_input()) + } else { + false + } + } } } @@ -1365,11 +1365,13 @@ impl Type { Type::Tuple(elements) => { elements.iter().all(|elem| elem.is_valid_for_unconstrained_boundary()) } - Type::DataType(definition, generics) => definition - .borrow() - .get_fields(generics) - .into_iter() - .all(|(_, field)| field.is_valid_for_unconstrained_boundary()), + Type::DataType(definition, generics) => { + if let Some(fields) = definition.borrow().get_fields(generics) { + fields.into_iter().all(|(_, field)| field.is_valid_for_unconstrained_boundary()) + } else { + false + } + } } } @@ -1513,8 +1515,19 @@ impl Type { } Type::DataType(def, args) => { let struct_type = def.borrow(); - let fields = struct_type.get_fields(args); - fields.iter().fold(0, |acc, (_, field_type)| acc + field_type.field_count(location)) + if let Some(fields) = struct_type.get_fields(args) { + fields.iter().map(|(_, field_type)| field_type.field_count(location)).sum() + } else if let Some(variants) = struct_type.get_variants(args) { + let mut size = 1; // start with the tag size + for (_, args) in variants { + for arg in args { + size += arg.field_count(location); + } + } + size + } else { + 0 + } } Type::CheckedCast { to, .. } => to.field_count(location), Type::Alias(def, generics) => def.borrow().get_type(generics).field_count(location), @@ -1552,10 +1565,14 @@ impl Type { pub(crate) fn contains_slice(&self) -> bool { match self { Type::Slice(_) => true, - Type::DataType(struct_typ, generics) => { - let fields = struct_typ.borrow().get_fields(generics); - for field in fields.iter() { - if field.1.contains_slice() { + Type::DataType(typ, generics) => { + let typ = typ.borrow(); + if let Some(fields) = typ.get_fields(generics) { + if fields.iter().any(|(_, field)| field.contains_slice()) { + return true; + } + } else if let Some(variants) = typ.get_variants(generics) { + if variants.iter().flat_map(|(_, args)| args).any(|typ| typ.contains_slice()) { return true; } } @@ -2851,10 +2868,19 @@ impl From<&Type> for PrintableType { Type::Unit => PrintableType::Unit, Type::Constant(_, _) => unreachable!(), Type::DataType(def, ref args) => { - let struct_type = def.borrow(); - let fields = struct_type.get_fields(args); - let fields = vecmap(fields, |(name, typ)| (name, typ.into())); - PrintableType::Struct { fields, name: struct_type.name.to_string() } + let data_type = def.borrow(); + let name = data_type.name.to_string(); + + if let Some(fields) = data_type.get_fields(args) { + let fields = vecmap(fields, |(name, typ)| (name, typ.into())); + PrintableType::Struct { fields, name } + } else if let Some(variants) = data_type.get_variants(args) { + let variants = + vecmap(variants, |(name, args)| (name, vecmap(args, Into::into))); + PrintableType::Enum { name, variants } + } else { + unreachable!() + } } Type::Alias(alias, args) => alias.borrow().get_type(args).into(), Type::TraitAsType(..) => unreachable!(), diff --git a/compiler/noirc_frontend/src/locations.rs b/compiler/noirc_frontend/src/locations.rs index 33c37172b50..fcc666f13e8 100644 --- a/compiler/noirc_frontend/src/locations.rs +++ b/compiler/noirc_frontend/src/locations.rs @@ -60,7 +60,7 @@ impl NodeInterner { match reference { ReferenceId::Module(id) => self.module_attributes(&id).location, ReferenceId::Function(id) => self.function_modifiers(&id).name_location, - ReferenceId::Struct(id) | ReferenceId::Enum(id) => { + ReferenceId::Type(id) => { let typ = self.get_type(id); let typ = typ.borrow(); Location::new(typ.name.span(), typ.location.file) @@ -109,8 +109,8 @@ impl NodeInterner { ModuleDefId::FunctionId(func_id) => { self.add_function_reference(func_id, location); } - ModuleDefId::TypeId(struct_id) => { - self.add_struct_reference(struct_id, location, is_self_type); + ModuleDefId::TypeId(type_id) => { + self.add_type_reference(type_id, location, is_self_type); } ModuleDefId::TraitId(trait_id) => { self.add_trait_reference(trait_id, location, is_self_type); @@ -128,13 +128,13 @@ impl NodeInterner { self.add_reference(ReferenceId::Module(id), location, false); } - pub(crate) fn add_struct_reference( + pub(crate) fn add_type_reference( &mut self, id: TypeId, location: Location, is_self_type: bool, ) { - self.add_reference(ReferenceId::Struct(id), location, is_self_type); + self.add_reference(ReferenceId::Type(id), location, is_self_type); } pub(crate) fn add_struct_member_reference( @@ -326,25 +326,14 @@ impl NodeInterner { self.register_name_for_auto_import(name, ModuleDefId::GlobalId(id), visibility, None); } - pub(crate) fn register_struct( + pub(crate) fn register_type( &mut self, id: TypeId, name: String, visibility: ItemVisibility, parent_module_id: ModuleId, ) { - self.add_definition_location(ReferenceId::Struct(id), Some(parent_module_id)); - self.register_name_for_auto_import(name, ModuleDefId::TypeId(id), visibility, None); - } - - pub(crate) fn register_enum( - &mut self, - id: TypeId, - name: String, - visibility: ItemVisibility, - parent_module_id: ModuleId, - ) { - self.add_definition_location(ReferenceId::Enum(id), Some(parent_module_id)); + self.add_definition_location(ReferenceId::Type(id), Some(parent_module_id)); self.register_name_for_auto_import(name, ModuleDefId::TypeId(id), visibility, None); } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 3eb75037590..de8e4f6f864 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -1212,20 +1212,22 @@ impl<'interner> Monomorphizer<'interner> { Self::check_type(arg, location)?; } - if def.borrow().is_struct() { - let fields = def.borrow().get_fields(args); + let def = def.borrow(); + if let Some(fields) = def.get_fields(args) { let fields = try_vecmap(fields, |(_, field)| Self::convert_type(&field, location))?; ast::Type::Tuple(fields) - } else { + } else if let Some(variants) = def.get_variants(args) { // Enums are represented as (tag, variant1, variant2, .., variantN) let mut fields = vec![ast::Type::Field]; - for (_, variant_fields) in def.borrow().get_variants(args) { + for (_, variant_fields) in variants { let variant_fields = try_vecmap(variant_fields, |typ| Self::convert_type(&typ, location))?; fields.push(ast::Type::Tuple(variant_fields)); } ast::Type::Tuple(fields) + } else { + unreachable!("Data type has no body") } } @@ -2198,7 +2200,7 @@ fn unwrap_struct_type( Monomorphizer::check_type(arg, location)?; } - Ok(def.borrow().get_fields(&args)) + Ok(def.borrow().get_fields(&args).unwrap()) } other => unreachable!("unwrap_struct_type: expected struct, found {:?}", other), } @@ -2215,7 +2217,7 @@ fn unwrap_enum_type( Monomorphizer::check_type(arg, location)?; } - Ok(def.borrow().get_variants(&args)) + Ok(def.borrow().get_variants(&args).unwrap()) } other => unreachable!("unwrap_enum_type: expected enum, found {:?}", other), } diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 3d74cc1bf37..024a98d8475 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -299,9 +299,8 @@ pub enum DependencyId { #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum ReferenceId { Module(ModuleId), - Struct(TypeId), + Type(TypeId), StructMember(TypeId, usize), - Enum(TypeId), EnumVariant(TypeId, usize), Trait(TraitId), Global(GlobalId), diff --git a/compiler/noirc_frontend/src/resolve_locations.rs b/compiler/noirc_frontend/src/resolve_locations.rs index 1b904f653bd..4daf088a2f1 100644 --- a/compiler/noirc_frontend/src/resolve_locations.rs +++ b/compiler/noirc_frontend/src/resolve_locations.rs @@ -155,7 +155,7 @@ impl NodeInterner { }; let struct_type = lhs_self_struct.borrow(); - let field_names = struct_type.field_names(); + let field_names = struct_type.field_names()?; field_names.iter().find(|field_name| field_name.0 == expr_rhs.0).map(|found_field_name| { Location::new(found_field_name.span(), struct_type.location.file) diff --git a/compiler/noirc_printable_type/src/lib.rs b/compiler/noirc_printable_type/src/lib.rs index eb74d2470fb..1831180d0ab 100644 --- a/compiler/noirc_printable_type/src/lib.rs +++ b/compiler/noirc_printable_type/src/lib.rs @@ -36,6 +36,10 @@ pub enum PrintableType { name: String, fields: Vec<(String, PrintableType)>, }, + Enum { + name: String, + variants: Vec<(String, Vec)>, + }, String { length: u32, }, diff --git a/tooling/lsp/src/modules.rs b/tooling/lsp/src/modules.rs index b023f3886c3..758322fa4bc 100644 --- a/tooling/lsp/src/modules.rs +++ b/tooling/lsp/src/modules.rs @@ -16,7 +16,7 @@ pub(crate) fn module_def_id_to_reference_id(module_def_id: ModuleDefId) -> Refer match module_def_id { ModuleDefId::ModuleId(id) => ReferenceId::Module(id), ModuleDefId::FunctionId(id) => ReferenceId::Function(id), - ModuleDefId::TypeId(id) => ReferenceId::Struct(id), + ModuleDefId::TypeId(id) => ReferenceId::Type(id), ModuleDefId::TypeAliasId(id) => ReferenceId::Alias(id), ModuleDefId::TraitId(id) => ReferenceId::Trait(id), ModuleDefId::GlobalId(id) => ReferenceId::Global(id), diff --git a/tooling/lsp/src/requests/code_action/fill_struct_fields.rs b/tooling/lsp/src/requests/code_action/fill_struct_fields.rs index 7a4d562e402..fc8be7c5163 100644 --- a/tooling/lsp/src/requests/code_action/fill_struct_fields.rs +++ b/tooling/lsp/src/requests/code_action/fill_struct_fields.rs @@ -20,25 +20,23 @@ impl<'a> CodeActionFinder<'a> { }; let location = Location::new(path.span, self.file); - let Some(ReferenceId::Struct(struct_id)) = self.interner.find_referenced(location) else { + let Some(ReferenceId::Type(type_id)) = self.interner.find_referenced(location) else { return; }; - let struct_type = self.interner.get_type(struct_id); - let struct_type = struct_type.borrow(); + let typ = self.interner.get_type(type_id); + let typ = typ.borrow(); // First get all of the struct's fields - let mut fields = struct_type.get_fields_as_written(); + let Some(mut fields) = typ.get_fields_as_written() else { + return; + }; // Remove the ones that already exists in the constructor for (constructor_field, _) in &constructor.fields { fields.retain(|field| field.name.0.contents != constructor_field.0.contents); } - if fields.is_empty() { - return; - } - // Some fields are missing. Let's suggest a quick fix that adds them. let bytes = self.source.as_bytes(); let right_brace_index = span.end() as usize - 1; diff --git a/tooling/lsp/src/requests/completion.rs b/tooling/lsp/src/requests/completion.rs index 9948a29691e..0c51772935a 100644 --- a/tooling/lsp/src/requests/completion.rs +++ b/tooling/lsp/src/requests/completion.rs @@ -199,7 +199,7 @@ impl<'a> NodeFinder<'a> { }; let location = Location::new(span, self.file); - let Some(ReferenceId::Struct(struct_id)) = self.interner.find_referenced(location) else { + let Some(ReferenceId::Type(struct_id)) = self.interner.find_referenced(location) else { return; }; @@ -207,8 +207,11 @@ impl<'a> NodeFinder<'a> { let struct_type = struct_type.borrow(); // First get all of the struct's fields - let mut fields: Vec<_> = - struct_type.get_fields_as_written().into_iter().enumerate().collect(); + let Some(fields) = struct_type.get_fields_as_written() else { + return; + }; + + let mut fields = fields.into_iter().enumerate().collect::>(); // Remove the ones that already exists in the constructor for (used_name, _) in &constructor_expression.fields { @@ -805,9 +808,11 @@ impl<'a> NodeFinder<'a> { prefix: &str, self_prefix: bool, ) { - for (field_index, (name, visibility, typ)) in - struct_type.get_fields_with_visibility(generics).iter().enumerate() - { + let Some(fields) = struct_type.get_fields_with_visibility(generics) else { + return; + }; + + for (field_index, (name, visibility, typ)) in fields.iter().enumerate() { if !struct_member_is_visible(struct_type.id, *visibility, self.module_id, self.def_maps) { continue; @@ -1953,8 +1958,7 @@ fn name_matches(name: &str, prefix: &str) -> bool { fn module_def_id_from_reference_id(reference_id: ReferenceId) -> Option { match reference_id { ReferenceId::Module(module_id) => Some(ModuleDefId::ModuleId(module_id)), - ReferenceId::Struct(struct_id) => Some(ModuleDefId::TypeId(struct_id)), - ReferenceId::Enum(enum_id) => Some(ModuleDefId::TypeId(enum_id)), + ReferenceId::Type(struct_id) => Some(ModuleDefId::TypeId(struct_id)), ReferenceId::Trait(trait_id) => Some(ModuleDefId::TraitId(trait_id)), ReferenceId::Function(func_id) => Some(ModuleDefId::FunctionId(func_id)), ReferenceId::Alias(type_alias_id) => Some(ModuleDefId::TypeAliasId(type_alias_id)), diff --git a/tooling/lsp/src/requests/completion/completion_items.rs b/tooling/lsp/src/requests/completion/completion_items.rs index c8ae16bf1f4..039b745172b 100644 --- a/tooling/lsp/src/requests/completion/completion_items.rs +++ b/tooling/lsp/src/requests/completion/completion_items.rs @@ -113,7 +113,7 @@ impl<'a> NodeFinder<'a> { fn struct_completion_item(&self, name: String, struct_id: TypeId) -> CompletionItem { let completion_item = simple_completion_item(name.clone(), CompletionItemKind::STRUCT, Some(name)); - self.completion_item_with_doc_comments(ReferenceId::Struct(struct_id), completion_item) + self.completion_item_with_doc_comments(ReferenceId::Type(struct_id), completion_item) } pub(super) fn struct_field_completion_item( @@ -288,10 +288,10 @@ impl<'a> NodeFinder<'a> { } else { false }; + let description = func_meta_type_to_string(func_meta, name, func_self_type.is_some()); let name = if self_prefix { format!("self.{}", name) } else { name.clone() }; let name = if is_macro_call { format!("{}!", name) } else { name }; let name = &name; - let description = func_meta_type_to_string(func_meta, func_self_type.is_some()); let mut has_arguments = false; let completion_item = match function_completion_kind { @@ -351,7 +351,16 @@ impl<'a> NodeFinder<'a> { self.auto_import_trait_if_trait_method(func_id, trait_info, &mut completion_item); - self.completion_item_with_doc_comments(ReferenceId::Function(func_id), completion_item) + if let (Some(type_id), Some(variant_index)) = + (func_meta.type_id, func_meta.enum_variant_index) + { + self.completion_item_with_doc_comments( + ReferenceId::EnumVariant(type_id, variant_index), + completion_item, + ) + } else { + self.completion_item_with_doc_comments(ReferenceId::Function(func_id), completion_item) + } } fn auto_import_trait_if_trait_method( @@ -419,6 +428,8 @@ impl<'a> NodeFinder<'a> { function_kind: FunctionKind, skip_first_argument: bool, ) -> String { + let is_enum_variant = func_meta.enum_variant_index.is_some(); + let mut text = String::new(); text.push_str(name); text.push('('); @@ -448,7 +459,11 @@ impl<'a> NodeFinder<'a> { text.push_str("${"); text.push_str(&index.to_string()); text.push(':'); - self.hir_pattern_to_argument(pattern, &mut text); + if is_enum_variant { + text.push_str("()"); + } else { + self.hir_pattern_to_argument(pattern, &mut text); + } text.push('}'); index += 1; @@ -512,18 +527,25 @@ pub(super) fn trait_impl_method_completion_item( snippet_completion_item(label, CompletionItemKind::METHOD, insert_text, None) } -fn func_meta_type_to_string(func_meta: &FuncMeta, has_self_type: bool) -> String { +fn func_meta_type_to_string(func_meta: &FuncMeta, name: &str, has_self_type: bool) -> String { let mut typ = &func_meta.typ; if let Type::Forall(_, typ_) = typ { typ = typ_; } + let is_enum_variant = func_meta.enum_variant_index.is_some(); + if let Type::Function(args, ret, _env, unconstrained) = typ { let mut string = String::new(); - if *unconstrained { - string.push_str("unconstrained "); + if is_enum_variant { + string.push_str(name); + string.push('('); + } else { + if *unconstrained { + string.push_str("unconstrained "); + } + string.push_str("fn("); } - string.push_str("fn("); for (index, arg) in args.iter().enumerate() { if index > 0 { string.push_str(", "); @@ -536,13 +558,16 @@ fn func_meta_type_to_string(func_meta: &FuncMeta, has_self_type: bool) -> String } string.push(')'); - let ret: &Type = ret; - if let Type::Unit = ret { - // Nothing - } else { - string.push_str(" -> "); - string.push_str(&ret.to_string()); + if !is_enum_variant { + let ret: &Type = ret; + if let Type::Unit = ret { + // Nothing + } else { + string.push_str(" -> "); + string.push_str(&ret.to_string()); + } } + string } else { typ.to_string() diff --git a/tooling/lsp/src/requests/completion/tests.rs b/tooling/lsp/src/requests/completion/tests.rs index 8ff568e3c26..a3cd6b0d024 100644 --- a/tooling/lsp/src/requests/completion/tests.rs +++ b/tooling/lsp/src/requests/completion/tests.rs @@ -20,8 +20,8 @@ mod completion_tests { use lsp_types::{ CompletionItem, CompletionItemKind, CompletionItemLabelDetails, CompletionParams, - CompletionResponse, DidOpenTextDocumentParams, PartialResultParams, Position, - TextDocumentIdentifier, TextDocumentItem, TextDocumentPositionParams, + CompletionResponse, DidOpenTextDocumentParams, Documentation, PartialResultParams, + Position, TextDocumentIdentifier, TextDocumentItem, TextDocumentPositionParams, WorkDoneProgressParams, }; use tokio::test; @@ -3077,4 +3077,35 @@ fn main() { "#; assert_eq!(new_code, expected); } + + #[test] + async fn test_suggests_enum_variant_differently_than_a_function_call() { + let src = r#" + enum Enum { + /// Some docs + Variant(Field, i32) + } + + fn foo() { + Enum::Var>|< + } + "#; + let items = get_completions(src).await; + assert_eq!(items.len(), 1); + + let item = &items[0]; + assert_eq!(item.label, "Variant(…)".to_string()); + + let details = item.label_details.as_ref().unwrap(); + assert_eq!(details.description, Some("Variant(Field, i32)".to_string())); + + assert_eq!(item.detail, Some("Variant(Field, i32)".to_string())); + + assert_eq!(item.insert_text, Some("Variant(${1:()}, ${2:()})".to_string())); + + let Documentation::MarkupContent(markdown) = item.documentation.as_ref().unwrap() else { + panic!("Expected markdown docs"); + }; + assert!(markdown.value.contains("Some docs")); + } } diff --git a/tooling/lsp/src/requests/hover.rs b/tooling/lsp/src/requests/hover.rs index 8b9c0bb4b28..60c2a686a62 100644 --- a/tooling/lsp/src/requests/hover.rs +++ b/tooling/lsp/src/requests/hover.rs @@ -16,7 +16,8 @@ use noirc_frontend::{ DefinitionId, DefinitionKind, ExprId, FuncId, GlobalId, NodeInterner, ReferenceId, TraitId, TraitImplKind, TypeAliasId, TypeId, }, - DataType, Generics, Shared, Type, TypeAlias, TypeBinding, TypeVariable, + DataType, EnumVariant, Generics, Shared, StructField, Type, TypeAlias, TypeBinding, + TypeVariable, }; use crate::{ @@ -73,11 +74,10 @@ pub(crate) fn on_hover_request( fn format_reference(reference: ReferenceId, args: &ProcessRequestCallbackArgs) -> Option { match reference { ReferenceId::Module(id) => format_module(id, args), - ReferenceId::Struct(id) => Some(format_struct(id, args)), + ReferenceId::Type(id) => Some(format_type(id, args)), ReferenceId::StructMember(id, field_index) => { Some(format_struct_member(id, field_index, args)) } - ReferenceId::Enum(id) => Some(format_enum(id, args)), ReferenceId::EnumVariant(id, variant_index) => { Some(format_enum_variant(id, variant_index, args)) } @@ -126,20 +126,33 @@ fn format_module(id: ModuleId, args: &ProcessRequestCallbackArgs) -> Option String { - let struct_type = args.interner.get_type(id); - let struct_type = struct_type.borrow(); +fn format_type(id: TypeId, args: &ProcessRequestCallbackArgs) -> String { + let typ = args.interner.get_type(id); + let typ = typ.borrow(); + if let Some(fields) = typ.get_fields_as_written() { + format_struct(&typ, fields, args) + } else if let Some(variants) = typ.get_variants_as_written() { + format_enum(&typ, variants, args) + } else { + unreachable!("Type should either be a struct or an enum") + } +} +fn format_struct( + typ: &DataType, + fields: Vec, + args: &ProcessRequestCallbackArgs, +) -> String { let mut string = String::new(); - if format_parent_module(ReferenceId::Struct(id), args, &mut string) { + if format_parent_module(ReferenceId::Type(typ.id), args, &mut string) { string.push('\n'); } string.push_str(" "); string.push_str("struct "); - string.push_str(&struct_type.name.0.contents); - format_generics(&struct_type.generics, &mut string); + string.push_str(&typ.name.0.contents); + format_generics(&typ.generics, &mut string); string.push_str(" {\n"); - for field in struct_type.get_fields_as_written() { + for field in fields { string.push_str(" "); string.push_str(&field.name.0.contents); string.push_str(": "); @@ -148,17 +161,18 @@ fn format_struct(id: TypeId, args: &ProcessRequestCallbackArgs) -> String { } string.push_str(" }"); - append_doc_comments(args.interner, ReferenceId::Struct(id), &mut string); + append_doc_comments(args.interner, ReferenceId::Type(typ.id), &mut string); string } -fn format_enum(id: TypeId, args: &ProcessRequestCallbackArgs) -> String { - let typ = args.interner.get_type(id); - let typ = typ.borrow(); - +fn format_enum( + typ: &DataType, + variants: Vec, + args: &ProcessRequestCallbackArgs, +) -> String { let mut string = String::new(); - if format_parent_module(ReferenceId::Enum(id), args, &mut string) { + if format_parent_module(ReferenceId::Type(typ.id), args, &mut string) { string.push('\n'); } string.push_str(" "); @@ -166,7 +180,7 @@ fn format_enum(id: TypeId, args: &ProcessRequestCallbackArgs) -> String { string.push_str(&typ.name.0.contents); format_generics(&typ.generics, &mut string); string.push_str(" {\n"); - for field in typ.get_variants_as_written() { + for field in variants { string.push_str(" "); string.push_str(&field.name.0.contents); @@ -181,7 +195,7 @@ fn format_enum(id: TypeId, args: &ProcessRequestCallbackArgs) -> String { } string.push_str(" }"); - append_doc_comments(args.interner, ReferenceId::Enum(id), &mut string); + append_doc_comments(args.interner, ReferenceId::Type(typ.id), &mut string); string } @@ -196,7 +210,7 @@ fn format_struct_member( let field = struct_type.field_at(field_index); let mut string = String::new(); - if format_parent_module(ReferenceId::Struct(id), args, &mut string) { + if format_parent_module(ReferenceId::Type(id), args, &mut string) { string.push_str("::"); } string.push_str(&struct_type.name.0.contents); @@ -222,7 +236,7 @@ fn format_enum_variant( let variant = enum_type.variant_at(field_index); let mut string = String::new(); - if format_parent_module(ReferenceId::Enum(id), args, &mut string) { + if format_parent_module(ReferenceId::Type(id), args, &mut string) { string.push_str("::"); } string.push_str(&enum_type.name.0.contents); @@ -377,9 +391,19 @@ fn format_function(id: FuncId, args: &ProcessRequestCallbackArgs) -> String { let func_name_definition_id = args.interner.definition(func_meta.name.id); + let enum_variant = match (func_meta.type_id, func_meta.enum_variant_index) { + (Some(type_id), Some(index)) => Some((type_id, index)), + _ => None, + }; + + let reference_id = if let Some((type_id, variant_index)) = enum_variant { + ReferenceId::EnumVariant(type_id, variant_index) + } else { + ReferenceId::Function(id) + }; + let mut string = String::new(); - let formatted_parent_module = - format_parent_module(ReferenceId::Function(id), args, &mut string); + let formatted_parent_module = format_parent_module(reference_id, args, &mut string); let formatted_parent_type = if let Some(trait_impl_id) = func_meta.trait_impl { let trait_impl = args.interner.get_trait_implementation(trait_impl_id); @@ -444,21 +468,23 @@ fn format_function(id: FuncId, args: &ProcessRequestCallbackArgs) -> String { string.push_str("::"); } string.push_str(&data_type.name.0.contents); - string.push('\n'); - string.push_str(" "); - string.push_str("impl"); + if enum_variant.is_none() { + string.push('\n'); + string.push_str(" "); + string.push_str("impl"); - let impl_generics: Vec<_> = func_meta - .all_generics - .iter() - .take(func_meta.all_generics.len() - func_meta.direct_generics.len()) - .cloned() - .collect(); - format_generics(&impl_generics, &mut string); + let impl_generics: Vec<_> = func_meta + .all_generics + .iter() + .take(func_meta.all_generics.len() - func_meta.direct_generics.len()) + .cloned() + .collect(); + format_generics(&impl_generics, &mut string); - string.push(' '); - string.push_str(&data_type.name.0.contents); - format_generic_names(&impl_generics, &mut string); + string.push(' '); + string.push_str(&data_type.name.0.contents); + format_generic_names(&impl_generics, &mut string); + } true } else { @@ -485,7 +511,9 @@ fn format_function(id: FuncId, args: &ProcessRequestCallbackArgs) -> String { let func_name = &func_name_definition_id.name; - string.push_str("fn "); + if enum_variant.is_none() { + string.push_str("fn "); + } string.push_str(func_name); format_generics(&func_meta.direct_generics, &mut string); string.push('('); @@ -498,15 +526,19 @@ fn format_function(id: FuncId, args: &ProcessRequestCallbackArgs) -> String { string.push_str("&mut "); } - format_pattern(pattern, args.interner, &mut string); + if enum_variant.is_some() { + string.push_str(&format!("{}", typ)); + } else { + format_pattern(pattern, args.interner, &mut string); - // Don't add type for `self` param - if !is_self { - string.push_str(": "); - if matches!(visibility, Visibility::Public) { - string.push_str("pub "); + // Don't add type for `self` param + if !is_self { + string.push_str(": "); + if matches!(visibility, Visibility::Public) { + string.push_str("pub "); + } + string.push_str(&format!("{}", typ)); } - string.push_str(&format!("{}", typ)); } if index != parameters.len() - 1 { @@ -516,28 +548,34 @@ fn format_function(id: FuncId, args: &ProcessRequestCallbackArgs) -> String { string.push(')'); - let return_type = func_meta.return_type(); - match return_type { - Type::Unit => (), - _ => { - string.push_str(" -> "); - string.push_str(&format!("{}", return_type)); + if enum_variant.is_none() { + let return_type = func_meta.return_type(); + match return_type { + Type::Unit => (), + _ => { + string.push_str(" -> "); + string.push_str(&format!("{}", return_type)); + } } - } - string.push_str(&go_to_type_links(return_type, args.interner, args.files)); + string.push_str(&go_to_type_links(return_type, args.interner, args.files)); + } - let had_doc_comments = - append_doc_comments(args.interner, ReferenceId::Function(id), &mut string); - if !had_doc_comments { - // If this function doesn't have doc comments, but it's a trait impl method, - // use the trait method doc comments. - if let Some(trait_impl_id) = func_meta.trait_impl { - let trait_impl = args.interner.get_trait_implementation(trait_impl_id); - let trait_impl = trait_impl.borrow(); - let trait_ = args.interner.get_trait(trait_impl.trait_id); - if let Some(func_id) = trait_.method_ids.get(func_name) { - append_doc_comments(args.interner, ReferenceId::Function(*func_id), &mut string); + if enum_variant.is_some() { + append_doc_comments(args.interner, reference_id, &mut string); + } else { + let had_doc_comments = append_doc_comments(args.interner, reference_id, &mut string); + if !had_doc_comments { + // If this function doesn't have doc comments, but it's a trait impl method, + // use the trait method doc comments. + if let Some(trait_impl_id) = func_meta.trait_impl { + let trait_impl = args.interner.get_trait_implementation(trait_impl_id); + let trait_impl = trait_impl.borrow(); + let trait_ = args.interner.get_trait(trait_impl.trait_id); + if let Some(func_id) = trait_.method_ids.get(func_name) { + let reference_id = ReferenceId::Function(*func_id); + append_doc_comments(args.interner, reference_id, &mut string); + } } } } @@ -1256,4 +1294,67 @@ mod hover_tests { .await; assert!(hover_text.contains("fn mut_self(&mut self)")); } + + #[test] + async fn hover_on_empty_enum_type() { + let hover_text = + get_hover_text("workspace", "two/src/lib.nr", Position { line: 100, character: 8 }) + .await; + assert!(hover_text.contains( + " two + enum EmptyColor { + } + +--- + + Red, blue, etc." + )); + } + + #[test] + async fn hover_on_non_empty_enum_type() { + let hover_text = + get_hover_text("workspace", "two/src/lib.nr", Position { line: 103, character: 8 }) + .await; + assert!(hover_text.contains( + " two + enum Color { + Red(Field), + } + +--- + + Red, blue, etc." + )); + } + + #[test] + async fn hover_on_enum_variant() { + let hover_text = + get_hover_text("workspace", "two/src/lib.nr", Position { line: 105, character: 6 }) + .await; + assert!(hover_text.contains( + " two::Color + Red(Field) + +--- + + Like a tomato" + )); + } + + #[test] + async fn hover_on_enum_variant_in_call() { + let hover_text = + get_hover_text("workspace", "two/src/lib.nr", Position { line: 109, character: 12 }) + .await; + assert!(hover_text.contains( + " two::Color + Red(Field) + +--- + + Like a tomato" + )); + } } diff --git a/tooling/lsp/src/requests/inlay_hint.rs b/tooling/lsp/src/requests/inlay_hint.rs index 1798f845a31..cbf4ed26ef9 100644 --- a/tooling/lsp/src/requests/inlay_hint.rs +++ b/tooling/lsp/src/requests/inlay_hint.rs @@ -109,8 +109,7 @@ impl<'a> InlayHintCollector<'a> { self.push_type_hint(lsp_location, &variant_type, false, include_colon); } ReferenceId::Module(_) - | ReferenceId::Struct(_) - | ReferenceId::Enum(_) + | ReferenceId::Type(_) | ReferenceId::Trait(_) | ReferenceId::Function(_) | ReferenceId::Alias(_) @@ -174,6 +173,11 @@ impl<'a> InlayHintCollector<'a> { if let Some(ReferenceId::Function(func_id)) = referenced { let func_meta = self.interner.function_meta(&func_id); + // No hints for enum variants + if func_meta.enum_variant_index.is_some() { + return; + } + let mut parameters = func_meta.parameters.iter().peekable(); let mut parameters_count = func_meta.parameters.len(); diff --git a/tooling/lsp/src/requests/signature_help.rs b/tooling/lsp/src/requests/signature_help.rs index c0d40656c19..9847586fad1 100644 --- a/tooling/lsp/src/requests/signature_help.rs +++ b/tooling/lsp/src/requests/signature_help.rs @@ -122,10 +122,22 @@ impl<'a> SignatureFinder<'a> { active_parameter: Option, has_self: bool, ) -> SignatureInformation { + let enum_type_id = match (func_meta.type_id, func_meta.enum_variant_index) { + (Some(type_id), Some(_)) => Some(type_id), + _ => None, + }; + let mut label = String::new(); let mut parameters = Vec::new(); - label.push_str("fn "); + if let Some(enum_type_id) = enum_type_id { + label.push_str("enum "); + label.push_str(&self.interner.get_type(enum_type_id).borrow().name.0.contents); + label.push_str("::"); + } else { + label.push_str("fn "); + } + label.push_str(name); label.push('('); for (index, (pattern, typ, _)) in func_meta.parameters.0.iter().enumerate() { @@ -142,8 +154,10 @@ impl<'a> SignatureFinder<'a> { } else { let parameter_start = label.chars().count(); - self.hir_pattern_to_argument(pattern, &mut label); - label.push_str(": "); + if enum_type_id.is_none() { + self.hir_pattern_to_argument(pattern, &mut label); + label.push_str(": "); + } label.push_str(&typ.to_string()); let parameter_end = label.chars().count(); @@ -159,11 +173,13 @@ impl<'a> SignatureFinder<'a> { } label.push(')'); - match &func_meta.return_type { - FunctionReturnType::Default(_) => (), - FunctionReturnType::Ty(typ) => { - label.push_str(" -> "); - label.push_str(&typ.to_string()); + if enum_type_id.is_none() { + match &func_meta.return_type { + FunctionReturnType::Default(_) => (), + FunctionReturnType::Ty(typ) => { + label.push_str(" -> "); + label.push_str(&typ.to_string()); + } } } diff --git a/tooling/lsp/src/requests/signature_help/tests.rs b/tooling/lsp/src/requests/signature_help/tests.rs index 4b3f3c38156..0aedc256652 100644 --- a/tooling/lsp/src/requests/signature_help/tests.rs +++ b/tooling/lsp/src/requests/signature_help/tests.rs @@ -240,4 +240,31 @@ mod signature_help_tests { assert_eq!(signature.active_parameter, Some(0)); } + + #[test] + async fn test_signature_help_for_enum_variant() { + let src = r#" + enum Enum { + Variant(Field, i32) + } + + fn bar() { + Enum::Variant(>|<(), ()); + } + "#; + + let signature_help = get_signature_help(src).await; + assert_eq!(signature_help.signatures.len(), 1); + + let signature = &signature_help.signatures[0]; + assert_eq!(signature.label, "enum Enum::Variant(Field, i32)"); + + let params = signature.parameters.as_ref().unwrap(); + assert_eq!(params.len(), 2); + + check_label(&signature.label, ¶ms[0].label, "Field"); + check_label(&signature.label, ¶ms[1].label, "i32"); + + assert_eq!(signature.active_parameter, Some(0)); + } } diff --git a/tooling/lsp/test_programs/workspace/two/src/lib.nr b/tooling/lsp/test_programs/workspace/two/src/lib.nr index aacc4508756..0baeb83d5c1 100644 --- a/tooling/lsp/test_programs/workspace/two/src/lib.nr +++ b/tooling/lsp/test_programs/workspace/two/src/lib.nr @@ -96,3 +96,16 @@ impl TraitWithDocs for Field { impl Foo { fn mut_self(&mut self) {} } + +/// Red, blue, etc. +enum EmptyColor {} + +/// Red, blue, etc. +enum Color { + /// Like a tomato + Red(Field), +} + +fn test_enum() -> Color { + Color::Red(1) +} diff --git a/tooling/noirc_abi/src/printable_type.rs b/tooling/noirc_abi/src/printable_type.rs index a81eb0ce8f6..e13cab06e9f 100644 --- a/tooling/noirc_abi/src/printable_type.rs +++ b/tooling/noirc_abi/src/printable_type.rs @@ -74,5 +74,15 @@ pub fn decode_value( decode_value(field_iterator, typ) } PrintableType::Unit => PrintableValue::Field(F::zero()), + PrintableType::Enum { name: _, variants } => { + let tag = field_iterator.next().unwrap(); + let tag_value = tag.to_u128() as usize; + + let (_name, variant_types) = &variants[tag_value]; + PrintableValue::Vec { + array_elements: vecmap(variant_types, |typ| decode_value(field_iterator, typ)), + is_slice: false, + } + } } }