diff --git a/aztec_macros/src/transforms/contract_interface.rs b/aztec_macros/src/transforms/contract_interface.rs index dd3ec7f6a75..7a8a7187857 100644 --- a/aztec_macros/src/transforms/contract_interface.rs +++ b/aztec_macros/src/transforms/contract_interface.rs @@ -324,101 +324,105 @@ pub fn update_fn_signatures_in_contract_interface( }); if let Some(interface_struct) = maybe_interface_struct { - let methods = context.def_interner.get_struct_methods(interface_struct.borrow().id); - - for func_id in methods.iter().flat_map(|methods| methods.direct.iter()) { - let name = context.def_interner.function_name(func_id); - let fn_parameters = &context.def_interner.function_meta(func_id).parameters.clone(); - - if name == "at" || name == "interface" || name == "storage" { - continue; - } - - let fn_signature_hash = compute_fn_signature_hash( - name, - &fn_parameters - .iter() - .skip(1) - .map(|(_, typ, _)| typ.clone()) - .collect::>(), - ); - let hir_func = context.def_interner.function(func_id).block(&context.def_interner); - - let function_selector_statement = context.def_interner.statement( - hir_func.statements().get(hir_func.statements().len() - 2).ok_or(( - AztecMacroError::CouldNotGenerateContractInterface { - secondary_message: Some( - "Function signature statement not found, invalid body length" - .to_string(), - ), - }, - file_id, - ))?, - ); - let function_selector_expression_id = match function_selector_statement { - HirStatement::Let(let_statement) => Ok(let_statement.expression), - _ => Err(( - AztecMacroError::CouldNotGenerateContractInterface { - secondary_message: Some( - "Function selector statement must be an expression".to_string(), - ), - }, - file_id, - )), - }?; - let function_selector_expression = - context.def_interner.expression(&function_selector_expression_id); - - let current_fn_signature_expression_id = match function_selector_expression { - HirExpression::Call(call_expr) => Ok(call_expr.arguments[0]), - _ => Err(( - AztecMacroError::CouldNotGenerateContractInterface { - secondary_message: Some( - "Function selector argument expression must be call expression" - .to_string(), - ), - }, - file_id, - )), - }?; - - let current_fn_signature_expression = - context.def_interner.expression(¤t_fn_signature_expression_id); + if let Some(methods) = + context.def_interner.get_struct_methods(interface_struct.borrow().id).cloned() + { + for func_id in methods.iter().flat_map(|(_name, methods)| methods.direct.iter()) { + let name = context.def_interner.function_name(func_id); + let fn_parameters = + &context.def_interner.function_meta(func_id).parameters.clone(); + + if name == "at" || name == "interface" || name == "storage" { + continue; + } - match current_fn_signature_expression { - HirExpression::Literal(HirLiteral::Integer(value, _)) => { - if !value.is_zero() { - Err(( + let fn_signature_hash = compute_fn_signature_hash( + name, + &fn_parameters + .iter() + .skip(1) + .map(|(_, typ, _)| typ.clone()) + .collect::>(), + ); + let hir_func = + context.def_interner.function(func_id).block(&context.def_interner); + + let function_selector_statement = context.def_interner.statement( + hir_func.statements().get(hir_func.statements().len() - 2).ok_or(( + AztecMacroError::CouldNotGenerateContractInterface { + secondary_message: Some( + "Function signature statement not found, invalid body length" + .to_string(), + ), + }, + file_id, + ))?, + ); + let function_selector_expression_id = match function_selector_statement { + HirStatement::Let(let_statement) => Ok(let_statement.expression), + _ => Err(( + AztecMacroError::CouldNotGenerateContractInterface { + secondary_message: Some( + "Function selector statement must be an expression".to_string(), + ), + }, + file_id, + )), + }?; + let function_selector_expression = + context.def_interner.expression(&function_selector_expression_id); + + let current_fn_signature_expression_id = match function_selector_expression { + HirExpression::Call(call_expr) => Ok(call_expr.arguments[0]), + _ => Err(( + AztecMacroError::CouldNotGenerateContractInterface { + secondary_message: Some( + "Function selector argument expression must be call expression" + .to_string(), + ), + }, + file_id, + )), + }?; + + let current_fn_signature_expression = + context.def_interner.expression(¤t_fn_signature_expression_id); + + match current_fn_signature_expression { + HirExpression::Literal(HirLiteral::Integer(value, _)) => { + if !value.is_zero() { + Err(( AztecMacroError::CouldNotGenerateContractInterface { secondary_message: Some( "Function signature argument must be a placeholder with value 0".to_string()), }, file_id, )) - } else { - Ok(()) + } else { + Ok(()) + } } - } - _ => Err(( - AztecMacroError::CouldNotGenerateContractInterface { - secondary_message: Some( - "Function signature argument must be a literal field element" - .to_string(), - ), + _ => Err(( + AztecMacroError::CouldNotGenerateContractInterface { + secondary_message: Some( + "Function signature argument must be a literal field element" + .to_string(), + ), + }, + file_id, + )), + }?; + + context.def_interner.update_expression( + current_fn_signature_expression_id, + |expr| { + *expr = HirExpression::Literal(HirLiteral::Integer( + FieldElement::from(fn_signature_hash as u128), + false, + )) }, - file_id, - )), - }?; - - context.def_interner.update_expression( - current_fn_signature_expression_id, - |expr| { - *expr = HirExpression::Literal(HirLiteral::Integer( - FieldElement::from(fn_signature_hash as u128), - false, - )) - }, - ); + ); + } } } } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 177d23c74dd..2aba8b918d6 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -665,12 +665,12 @@ impl std::fmt::Display for Type { Type::Function(args, ret, env) => { let closure_env_text = match **env { Type::Unit => "".to_string(), - _ => format!(" with env {env}"), + _ => format!("[{env}]"), }; let args = vecmap(args.iter(), ToString::to_string); - write!(f, "fn({}) -> {ret}{closure_env_text}", args.join(", ")) + write!(f, "fn{closure_env_text}({}) -> {ret}", args.join(", ")) } Type::MutableReference(element) => { write!(f, "&mut {element}") diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 43e742b940e..9b9bfd463b9 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -179,10 +179,10 @@ pub struct NodeInterner { /// may have both `impl Struct { fn foo(){} }` and `impl Struct { fn foo(){} }`. /// If this happens, the returned Vec will have 2 entries and we'll need to further /// disambiguate them by checking the type of each function. - struct_methods: HashMap<(StructId, String), Methods>, + struct_methods: HashMap>, /// Methods on primitive types defined in the stdlib. - primitive_methods: HashMap<(TypeMethodKey, String), Methods>, + primitive_methods: HashMap>, // For trait implementation functions, this is their self type and trait they belong to func_id_to_trait: HashMap, @@ -1131,22 +1131,27 @@ impl NodeInterner { self.structs[&id].clone() } - pub fn get_struct_methods(&self, id: StructId) -> Vec { - self.struct_methods - .keys() - .filter_map(|(key_id, name)| { - if key_id == &id { - Some( - self.struct_methods - .get(&(*key_id, name.clone())) - .expect("get_struct_methods given invalid StructId") - .clone(), - ) - } else { - None - } - }) - .collect() + pub fn get_struct_methods(&self, id: StructId) -> Option<&HashMap> { + self.struct_methods.get(&id) + } + + fn get_primitive_methods(&self, key: TypeMethodKey) -> Option<&HashMap> { + self.primitive_methods.get(&key) + } + + pub fn get_type_methods(&self, typ: &Type) -> Option<&HashMap> { + match typ { + Type::Struct(struct_type, _) => { + let struct_type = struct_type.borrow(); + self.get_struct_methods(struct_type.id) + } + Type::Alias(type_alias, generics) => { + let type_alias = type_alias.borrow(); + let typ = type_alias.get_type(generics); + self.get_type_methods(&typ) + } + _ => get_type_method_key(typ).and_then(|key| self.get_primitive_methods(key)), + } } pub fn get_trait(&self, id: TraitId) -> &Trait { @@ -1300,8 +1305,12 @@ impl NodeInterner { return Some(existing); } - let key = (id, method_name); - self.struct_methods.entry(key).or_default().add_method(method_id, is_trait_method); + self.struct_methods + .entry(id) + .or_default() + .entry(method_name) + .or_default() + .add_method(method_id, is_trait_method); None } Type::Error => None, @@ -1314,7 +1323,9 @@ impl NodeInterner { unreachable!("Cannot add a method to the unsupported type '{}'", other) }); self.primitive_methods - .entry((key, method_name)) + .entry(key) + .or_default() + .entry(method_name) .or_default() .add_method(method_id, is_trait_method); None @@ -1648,7 +1659,7 @@ impl NodeInterner { method_name: &str, force_type_check: bool, ) -> Option { - let methods = self.struct_methods.get(&(id, method_name.to_owned())); + let methods = self.struct_methods.get(&id).and_then(|h| h.get(method_name)); // If there is only one method, just return it immediately. // It will still be typechecked later. @@ -1673,8 +1684,8 @@ impl NodeInterner { } else { // Failed to find a match for the type in question, switch to looking at impls // for all types `T`, e.g. `impl Foo for T` - let key = &(TypeMethodKey::Generic, method_name.to_owned()); - let global_methods = self.primitive_methods.get(key)?; + let global_methods = + self.primitive_methods.get(&TypeMethodKey::Generic)?.get(method_name)?; global_methods.find_matching_method(typ, self) } } @@ -1682,7 +1693,7 @@ impl NodeInterner { /// Looks up a given method name on the given primitive type. pub fn lookup_primitive_method(&self, typ: &Type, method_name: &str) -> Option { let key = get_type_method_key(typ)?; - let methods = self.primitive_methods.get(&(key, method_name.to_owned()))?; + let methods = self.primitive_methods.get(&key)?.get(method_name)?; self.find_matching_method(typ, Some(methods), method_name) } @@ -1803,6 +1814,11 @@ impl NodeInterner { self.prefix_operator_traits.insert(operator, trait_id); } + pub fn is_operator_trait(&self, trait_id: TraitId) -> bool { + self.infix_operator_traits.values().any(|id| *id == trait_id) + || self.prefix_operator_traits.values().any(|id| *id == trait_id) + } + /// This function is needed when creating a NodeInterner for testing so that calls /// to `get_operator_trait` do not panic when the stdlib isn't present. #[cfg(test)] @@ -2017,7 +2033,7 @@ impl Methods { } /// Iterate through each method, starting with the direct methods - fn iter(&self) -> impl Iterator + '_ { + pub fn iter(&self) -> impl Iterator + '_ { self.direct.iter().copied().chain(self.trait_impl_methods.iter().copied()) } diff --git a/compiler/noirc_frontend/src/parser/errors.rs b/compiler/noirc_frontend/src/parser/errors.rs index 36d3ce5898c..b2de3b0794b 100644 --- a/compiler/noirc_frontend/src/parser/errors.rs +++ b/compiler/noirc_frontend/src/parser/errors.rs @@ -16,6 +16,8 @@ pub enum ParserErrorReason { ExpectedFieldName(Token), #[error("expected a pattern but found a type - {0}")] ExpectedPatternButFoundType(Token), + #[error("expected an identifier after .")] + ExpectedIdentifierAfterDot, #[error("expected an identifier after ::")] ExpectedIdentifierAfterColons, #[error("Expected a ; separating these two statements")] diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index 5a97d66df9a..876fc454565 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -847,6 +847,9 @@ where ArrayIndex(Expression), Cast(UnresolvedType), MemberAccess(UnaryRhsMemberAccess), + /// This is to allow `foo.` (no identifier afterwards) to be parsed as `foo` + /// and produce an error, rather than just erroring (for LSP). + JustADot, } // `(arg1, ..., argN)` in `my_func(arg1, ..., argN)` @@ -890,7 +893,16 @@ where }) .labelled(ParsingRuleLabel::FieldAccess); - let rhs = choice((call_rhs, array_rhs, cast_rhs, member_rhs)); + let just_a_dot = + just(Token::Dot).map(|_| UnaryRhs::JustADot).validate(|value, span, emit_error| { + emit_error(ParserError::with_reason( + ParserErrorReason::ExpectedIdentifierAfterDot, + span, + )); + value + }); + + let rhs = choice((call_rhs, array_rhs, cast_rhs, member_rhs, just_a_dot)); foldl_with_span( atom(expr_parser, expr_no_constructors, statement, allow_constructors), @@ -904,6 +916,7 @@ where UnaryRhs::MemberAccess(field) => { Expression::member_access_or_method_call(lhs, field, span) } + UnaryRhs::JustADot => lhs, }, ) } @@ -1673,4 +1686,27 @@ mod test { let block_expr = block_expr.expect("Failed to parse module"); assert_eq!(block_expr.statements.len(), 2); } + + #[test] + fn test_parses_member_access_without_member_name() { + let src = "{ foo. }"; + + let (Some(block_expression), errors) = parse_recover(block(fresh_statement()), src) else { + panic!("Expected to be able to parse a block expression"); + }; + + assert_eq!(errors.len(), 1); + assert_eq!(errors[0].message, "expected an identifier after ."); + + let statement = &block_expression.statements[0]; + let StatementKind::Expression(expr) = &statement.kind else { + panic!("Expected an expression statement"); + }; + + let ExpressionKind::Variable(var) = &expr.kind else { + panic!("Expected a variable expression"); + }; + + assert_eq!(var.to_string(), "foo"); + } } diff --git a/compiler/noirc_frontend/src/resolve_locations.rs b/compiler/noirc_frontend/src/resolve_locations.rs index efb430b75eb..61ff6baf919 100644 --- a/compiler/noirc_frontend/src/resolve_locations.rs +++ b/compiler/noirc_frontend/src/resolve_locations.rs @@ -31,6 +31,12 @@ impl NodeInterner { location_candidate.map(|(index, _location)| *index) } + /// Returns the Type of the expression that exists at the given location. + pub fn type_at_location(&self, location: Location) -> Option { + let index = self.find_location_index(location)?; + Some(self.id_type(index)) + } + /// Returns the [Location] of the definition of the given Ident found at [Span] of the given [FileId]. /// Returns [None] when definition is not found. pub fn get_definition_location_from( diff --git a/tooling/lsp/src/requests/completion.rs b/tooling/lsp/src/requests/completion.rs index 48616c0f52d..028bbcd6b28 100644 --- a/tooling/lsp/src/requests/completion.rs +++ b/tooling/lsp/src/requests/completion.rs @@ -31,7 +31,7 @@ use noirc_frontend::{ node_interner::{FuncId, GlobalId, ReferenceId, TraitId, TypeAliasId}, parser::{Item, ItemKind}, token::Keyword, - ParsedModule, Type, + ParsedModule, StructType, Type, }; use strum::IntoEnumIterator; @@ -55,13 +55,22 @@ enum ModuleCompletionKind { /// When suggest a function as a result of completion, whether to autocomplete its name or its name and parameters. #[derive(Clone, Copy, PartialEq, Eq, Debug)] -enum FunctionCompleteKind { +enum FunctionCompletionKind { // Only complete a function's name. This is used in use statement. Name, // Complete a function's name and parameters (as a snippet). This is used in regular code. NameAndParameters, } +/// Is there a requirement for suggesting functions? +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +enum FunctionKind<'a> { + /// No requirement: any function is okay to suggest. + Any, + /// Only show functions that have the given self type. + SelfType(&'a Type), +} + /// When requesting completions, whether to list all items or just types. /// For example, when writing `let x: S` we only want to suggest types at this /// point (modules too, because they might include types too). @@ -173,7 +182,16 @@ impl<'a> NodeFinder<'a> { if self.completion_items.is_empty() { None } else { - Some(CompletionResponse::Array(std::mem::take(&mut self.completion_items))) + let mut items = std::mem::take(&mut self.completion_items); + + // Show items that start with underscore last in the list + for item in items.iter_mut() { + if item.label.starts_with('_') { + item.sort_text = Some(underscore_sort_text()); + } + } + + Some(CompletionResponse::Array(items)) } } @@ -459,6 +477,19 @@ impl<'a> NodeFinder<'a> { } fn find_in_expression(&mut self, expression: &Expression) { + // "foo." (no identifier afterwards) is parsed as the expression on the left hand-side of the dot. + // Here we check if there's a dot at the completion position, and if the expression + // ends right before the dot. If so, it means we want to complete the expression's type fields and methods. + if self.byte == Some(b'.') && expression.span.end() as usize == self.byte_index - 1 { + let location = Location::new(expression.span, self.file); + if let Some(typ) = self.interner.type_at_location(location) { + let typ = typ.follow_bindings(); + let prefix = ""; + self.complete_type_fields_and_methods(&typ, prefix); + return; + } + } + match &expression.kind { noirc_frontend::ast::ExpressionKind::Literal(literal) => self.find_in_literal(literal), noirc_frontend::ast::ExpressionKind::Block(block_expression) => { @@ -572,6 +603,19 @@ impl<'a> NodeFinder<'a> { &mut self, member_access_expression: &MemberAccessExpression, ) { + let ident = &member_access_expression.rhs; + + if self.byte_index == ident.span().end() as usize { + // Assuming member_access_expression is of the form `foo.bar`, we are right after `bar` + let location = Location::new(member_access_expression.lhs.span, self.file); + if let Some(typ) = self.interner.type_at_location(location) { + let typ = typ.follow_bindings(); + let prefix = ident.to_string(); + self.complete_type_fields_and_methods(&typ, &prefix); + return; + } + } + self.find_in_expression(&member_access_expression.lhs); } @@ -707,19 +751,50 @@ impl<'a> NodeFinder<'a> { } let is_single_segment = !after_colons && idents.is_empty() && path.kind == PathKind::Plain; + let module_id; - let module_id = - if idents.is_empty() { Some(self.module_id) } else { self.resolve_module(idents) }; - let Some(module_id) = module_id else { - return; - }; + if idents.is_empty() { + module_id = self.module_id; + } else { + let Some(module_def_id) = self.resolve_path(idents) else { + return; + }; + + match module_def_id { + ModuleDefId::ModuleId(id) => module_id = id, + ModuleDefId::TypeId(struct_id) => { + let struct_type = self.interner.get_struct(struct_id); + self.complete_type_methods( + &Type::Struct(struct_type, vec![]), + &prefix, + FunctionKind::Any, + ); + return; + } + ModuleDefId::FunctionId(_) => { + // There's nothing inside a function + return; + } + ModuleDefId::TypeAliasId(type_alias_id) => { + let type_alias = self.interner.get_type_alias(type_alias_id); + let type_alias = type_alias.borrow(); + self.complete_type_methods(&type_alias.typ, &prefix, FunctionKind::Any); + return; + } + ModuleDefId::TraitId(_) => { + // For now we don't suggest trait methods + return; + } + ModuleDefId::GlobalId(_) => return, + } + } let module_completion_kind = if after_colons { ModuleCompletionKind::DirectChildren } else { ModuleCompletionKind::AllVisibleItems }; - let function_completion_kind = FunctionCompleteKind::NameAndParameters; + let function_completion_kind = FunctionCompletionKind::NameAndParameters; self.complete_in_module( module_id, @@ -827,7 +902,7 @@ impl<'a> NodeFinder<'a> { } let module_completion_kind = ModuleCompletionKind::DirectChildren; - let function_completion_kind = FunctionCompleteKind::Name; + let function_completion_kind = FunctionCompletionKind::Name; let requested_items = RequestedItems::AnyItems; if after_colons { @@ -913,6 +988,78 @@ impl<'a> NodeFinder<'a> { }; } + fn complete_type_fields_and_methods(&mut self, typ: &Type, prefix: &str) { + match typ { + Type::Struct(struct_type, generics) => { + self.complete_struct_fields(&struct_type.borrow(), generics, prefix); + } + Type::MutableReference(typ) => { + return self.complete_type_fields_and_methods(typ, prefix); + } + Type::Alias(type_alias, _) => { + let type_alias = type_alias.borrow(); + return self.complete_type_fields_and_methods(&type_alias.typ, prefix); + } + Type::FieldElement + | Type::Array(_, _) + | Type::Slice(_) + | Type::Integer(_, _) + | Type::Bool + | Type::String(_) + | Type::FmtString(_, _) + | Type::Unit + | Type::Tuple(_) + | Type::TypeVariable(_, _) + | Type::TraitAsType(_, _, _) + | Type::NamedGeneric(_, _, _) + | Type::Function(_, _, _) + | Type::Forall(_, _) + | Type::Constant(_) + | Type::Quoted(_) + | Type::InfixExpr(_, _, _) + | Type::Error => (), + } + + self.complete_type_methods(typ, prefix, FunctionKind::SelfType(typ)); + } + + fn complete_type_methods(&mut self, typ: &Type, prefix: &str, function_kind: FunctionKind) { + let Some(methods_by_name) = self.interner.get_type_methods(typ) else { + return; + }; + + for (name, methods) in methods_by_name { + for func_id in methods.iter() { + if name_matches(name, prefix) { + if let Some(completion_item) = self.function_completion_item( + func_id, + FunctionCompletionKind::NameAndParameters, + function_kind, + ) { + self.completion_items.push(completion_item); + } + } + } + } + } + + fn complete_struct_fields( + &mut self, + struct_type: &StructType, + generics: &[Type], + prefix: &str, + ) { + for (name, typ) in struct_type.get_fields(generics) { + if name_matches(&name, prefix) { + self.completion_items.push(simple_completion_item( + name, + CompletionItemKind::FIELD, + Some(typ.to_string()), + )); + } + } + } + #[allow(clippy::too_many_arguments)] fn complete_in_module( &mut self, @@ -921,7 +1068,7 @@ impl<'a> NodeFinder<'a> { path_kind: PathKind, at_root: bool, module_completion_kind: ModuleCompletionKind, - function_completion_kind: FunctionCompleteKind, + function_completion_kind: FunctionCompletionKind, requested_items: RequestedItems, ) { let def_map = &self.def_maps[&module_id.krate]; @@ -951,6 +1098,8 @@ impl<'a> NodeFinder<'a> { } } + let function_kind = FunctionKind::Any; + let items = match module_completion_kind { ModuleCompletionKind::DirectChildren => module_data.definitions(), ModuleCompletionKind::AllVisibleItems => module_data.scope(), @@ -966,6 +1115,7 @@ impl<'a> NodeFinder<'a> { module_def_id, name.clone(), function_completion_kind, + function_kind, requested_items, ) { self.completion_items.push(completion_item); @@ -977,6 +1127,7 @@ impl<'a> NodeFinder<'a> { module_def_id, name.clone(), function_completion_kind, + function_kind, requested_items, ) { self.completion_items.push(completion_item); @@ -1015,7 +1166,8 @@ impl<'a> NodeFinder<'a> { &self, module_def_id: ModuleDefId, name: String, - function_completion_kind: FunctionCompleteKind, + function_completion_kind: FunctionCompletionKind, + function_kind: FunctionKind, requested_items: RequestedItems, ) -> Option { match requested_items { @@ -1029,64 +1181,142 @@ impl<'a> NodeFinder<'a> { RequestedItems::AnyItems => (), } - let completion_item = match module_def_id { - ModuleDefId::ModuleId(_) => module_completion_item(name), + match module_def_id { + ModuleDefId::ModuleId(_) => Some(module_completion_item(name)), ModuleDefId::FunctionId(func_id) => { - self.function_completion_item(func_id, function_completion_kind) + self.function_completion_item(func_id, function_completion_kind, function_kind) } - ModuleDefId::TypeId(struct_id) => self.struct_completion_item(struct_id), + ModuleDefId::TypeId(struct_id) => Some(self.struct_completion_item(struct_id)), ModuleDefId::TypeAliasId(type_alias_id) => { - self.type_alias_completion_item(type_alias_id) + Some(self.type_alias_completion_item(type_alias_id)) } - ModuleDefId::TraitId(trait_id) => self.trait_completion_item(trait_id), - ModuleDefId::GlobalId(global_id) => self.global_completion_item(global_id), - }; - Some(completion_item) + ModuleDefId::TraitId(trait_id) => Some(self.trait_completion_item(trait_id)), + ModuleDefId::GlobalId(global_id) => Some(self.global_completion_item(global_id)), + } } fn function_completion_item( &self, func_id: FuncId, - function_completion_kind: FunctionCompleteKind, - ) -> CompletionItem { + function_completion_kind: FunctionCompletionKind, + function_kind: FunctionKind, + ) -> Option { let func_meta = self.interner.function_meta(&func_id); - let name = self.interner.function_name(&func_id).to_string(); + let name = &self.interner.function_name(&func_id).to_string(); + + let func_self_type = if let Some((pattern, typ, _)) = func_meta.parameters.0.get(0) { + if self.hir_pattern_is_self_type(pattern) { + if let Type::MutableReference(mut_typ) = typ { + let typ: &Type = mut_typ; + Some(typ) + } else { + Some(typ) + } + } else { + None + } + } else { + None + }; - let mut typ = &func_meta.typ; - if let Type::Forall(_, typ_) = typ { - typ = typ_; + match function_kind { + FunctionKind::Any => (), + FunctionKind::SelfType(mut self_type) => { + if let Some(func_self_type) = func_self_type { + if matches!(self_type, Type::Integer(..)) + || matches!(self_type, Type::FieldElement) + { + // Check that the pattern type is the same as self type. + // We do this because some types (only Field and integer types) + // have their methods in the same HashMap. + + if let Type::MutableReference(mut_typ) = self_type { + self_type = mut_typ; + } + + if self_type != func_self_type { + return None; + } + } + } else { + return None; + } + } } - let description = typ.to_string(); - let description = description.strip_suffix(" -> ()").unwrap_or(&description).to_string(); - match function_completion_kind { - FunctionCompleteKind::Name => { + let is_operator = if let Some(trait_impl_id) = &func_meta.trait_impl { + let trait_impl = self.interner.get_trait_implementation(*trait_impl_id); + let trait_impl = trait_impl.borrow(); + self.interner.is_operator_trait(trait_impl.trait_id) + } else { + false + }; + let description = func_meta_type_to_string(func_meta, func_self_type.is_some()); + + let completion_item = match function_completion_kind { + FunctionCompletionKind::Name => { simple_completion_item(name, CompletionItemKind::FUNCTION, Some(description)) } - FunctionCompleteKind::NameAndParameters => { - let label = format!("{}(…)", name); + FunctionCompletionKind::NameAndParameters => { let kind = CompletionItemKind::FUNCTION; - let insert_text = self.compute_function_insert_text(func_meta, &name); + let insert_text = self.compute_function_insert_text(func_meta, name, function_kind); + let label = if insert_text.ends_with("()") { + format!("{}()", name) + } else { + format!("{}(…)", name) + }; snippet_completion_item(label, kind, insert_text, Some(description)) } - } + }; + + let completion_item = if is_operator { + completion_item_with_sort_text(completion_item, operator_sort_text()) + } else if function_kind == FunctionKind::Any && name == "new" { + completion_item_with_sort_text(completion_item, new_sort_text()) + } else if function_kind == FunctionKind::Any && func_self_type.is_some() { + completion_item_with_sort_text(completion_item, self_mismatch_sort_text()) + } else { + completion_item + }; + + Some(completion_item) } - fn compute_function_insert_text(&self, func_meta: &FuncMeta, name: &str) -> String { + fn compute_function_insert_text( + &self, + func_meta: &FuncMeta, + name: &str, + function_kind: FunctionKind, + ) -> String { let mut text = String::new(); text.push_str(name); text.push('('); - for (index, (pattern, _, _)) in func_meta.parameters.0.iter().enumerate() { - if index > 0 { + + let mut index = 1; + for (pattern, _, _) in &func_meta.parameters.0 { + if index == 1 { + match function_kind { + FunctionKind::SelfType(_) => { + if self.hir_pattern_is_self_type(pattern) { + continue; + } + } + FunctionKind::Any => (), + } + } + + if index > 1 { text.push_str(", "); } text.push_str("${"); - text.push_str(&(index + 1).to_string()); + text.push_str(&index.to_string()); text.push(':'); self.hir_pattern_to_argument(pattern, &mut text); text.push('}'); + + index += 1; } text.push(')'); text @@ -1102,6 +1332,17 @@ impl<'a> NodeFinder<'a> { } } + fn hir_pattern_is_self_type(&self, pattern: &HirPattern) -> bool { + match pattern { + HirPattern::Identifier(hir_ident) => { + let name = self.interner.definition_name(hir_ident.id); + name == "self" || name == "_self" + } + HirPattern::Mutable(pattern, _) => self.hir_pattern_is_self_type(pattern), + HirPattern::Tuple(_, _) | HirPattern::Struct(_, _, _) => false, + } + } + fn struct_completion_item(&self, struct_id: StructId) -> CompletionItem { let struct_type = self.interner.get_struct(struct_id); let struct_type = struct_type.borrow(); @@ -1145,14 +1386,25 @@ impl<'a> NodeFinder<'a> { } fn resolve_path(&self, segments: Vec) -> Option { + let last_segment = segments.last().unwrap().clone(); + let path_segments = segments.into_iter().map(PathSegment::from).collect(); let path = Path { segments: path_segments, kind: PathKind::Plain, span: Span::default() }; let path_resolver = StandardPathResolver::new(self.root_module_id); - match path_resolver.resolve(self.def_maps, path, &mut None) { - Ok(path_resolution) => Some(path_resolution.module_def_id), - Err(_) => None, + if let Ok(path_resolution) = path_resolver.resolve(self.def_maps, path, &mut None) { + return Some(path_resolution.module_def_id); } + + // If we can't resolve a path trough lookup, let's see if the last segment is bound to a type + let location = Location::new(last_segment.span(), self.file); + if let Some(reference_id) = self.interner.find_referenced(location) { + if let Some(id) = module_def_id_from_reference_id(reference_id) { + return Some(id); + } + } + + None } fn builtin_functions_completion(&mut self, prefix: &str) { @@ -1236,7 +1488,7 @@ fn simple_completion_item( documentation: None, deprecated: None, preselect: None, - sort_text: None, + sort_text: Some(default_sort_text()), filter_text: None, insert_text: None, insert_text_format: None, @@ -1266,7 +1518,7 @@ fn snippet_completion_item( documentation: None, deprecated: None, preselect: None, - sort_text: None, + sort_text: Some(default_sort_text()), filter_text: None, insert_text_mode: None, text_edit: None, @@ -1278,6 +1530,98 @@ fn snippet_completion_item( } } +fn completion_item_with_sort_text( + completion_item: CompletionItem, + sort_text: String, +) -> CompletionItem { + CompletionItem { sort_text: Some(sort_text), ..completion_item } +} + +fn module_def_id_from_reference_id(reference_id: ReferenceId) -> Option { + match reference_id { + ReferenceId::Module(module_id) => Some(ModuleDefId::ModuleId(module_id)), + ReferenceId::Struct(struct_id) => Some(ModuleDefId::TypeId(struct_id)), + ReferenceId::Trait(trait_id) => Some(ModuleDefId::TraitId(trait_id)), + ReferenceId::Function(func_id) => Some(ModuleDefId::FunctionId(func_id)), + ReferenceId::Alias(type_alias_id) => Some(ModuleDefId::TypeAliasId(type_alias_id)), + ReferenceId::StructMember(_, _) + | ReferenceId::Global(_) + | ReferenceId::Local(_) + | ReferenceId::Reference(_, _) => None, + } +} + +fn func_meta_type_to_string(func_meta: &FuncMeta, has_self_type: bool) -> String { + let mut typ = &func_meta.typ; + if let Type::Forall(_, typ_) = typ { + typ = typ_; + } + + if let Type::Function(args, ret, _env) = typ { + let mut string = String::new(); + string.push_str("fn("); + for (index, arg) in args.iter().enumerate() { + if index > 0 { + string.push_str(", "); + } + if index == 0 && has_self_type { + type_to_self_string(arg, &mut string); + } else { + string.push_str(&arg.to_string()); + } + } + string.push(')'); + + let ret: &Type = ret; + if let Type::Unit = ret { + // Nothing + } else { + string.push_str(" -> "); + string.push_str(&ret.to_string()); + } + string + } else { + typ.to_string() + } +} + +fn type_to_self_string(typ: &Type, string: &mut String) { + if let Type::MutableReference(..) = typ { + string.push_str("&mut self"); + } else { + string.push_str("self"); + } +} + +/// Sort text for "new" methods: we want these to show up before anything else, +/// if we are completing at something like `Foo::` +fn new_sort_text() -> String { + "3".to_string() +} + +/// This is the default sort text. +fn default_sort_text() -> String { + "5".to_string() +} + +/// When completing something like `Foo::`, we want to show methods that take +/// self after the other ones. +fn self_mismatch_sort_text() -> String { + "7".to_string() +} + +/// We want to show operator methods last. +fn operator_sort_text() -> String { + "8".to_string() +} + +/// If a name begins with underscore it's likely something that's meant to +/// be private (but visibility doesn't exist everywhere yet, so for now +/// we assume that) +fn underscore_sort_text() -> String { + "9".to_string() +} + #[cfg(test)] mod completion_tests { use crate::{notifications::on_did_open_text_document, test_utils}; @@ -1664,6 +2008,27 @@ mod completion_tests { .await; } + #[test] + async fn test_complete_function_without_arguments() { + let src = r#" + fn hello() { } + + fn main() { + h>|< + } + "#; + assert_completion( + src, + vec![snippet_completion_item( + "hello()", + CompletionItemKind::FUNCTION, + "hello()", + Some("fn()".to_string()), + )], + ) + .await; + } + #[test] async fn test_complete_function() { let src = r#" @@ -2201,4 +2566,315 @@ mod completion_tests { ) .await; } + + #[test] + async fn test_suggests_struct_field_after_dot_and_letter() { + let src = r#" + struct Some { + property: i32, + } + + fn foo(s: Some) { + s.p>|< + } + "#; + assert_completion( + src, + vec![simple_completion_item( + "property", + CompletionItemKind::FIELD, + Some("i32".to_string()), + )], + ) + .await; + } + + #[test] + async fn test_suggests_struct_field_after_dot_and_letter_for_generic_type() { + let src = r#" + struct Some { + property: T, + } + + fn foo(s: Some) { + s.p>|< + } + "#; + assert_completion( + src, + vec![simple_completion_item( + "property", + CompletionItemKind::FIELD, + Some("i32".to_string()), + )], + ) + .await; + } + + #[test] + async fn test_suggests_struct_field_after_dot_followed_by_brace() { + let src = r#" + struct Some { + property: i32, + } + + fn foo(s: Some) { + s.>|< + } + "#; + assert_completion( + src, + vec![simple_completion_item( + "property", + CompletionItemKind::FIELD, + Some("i32".to_string()), + )], + ) + .await; + } + + #[test] + async fn test_suggests_struct_field_after_dot_chain() { + let src = r#" + struct Some { + property: Other, + } + + struct Other { + bar: i32, + } + + fn foo(some: Some) { + some.property.>|< + } + "#; + assert_completion( + src, + vec![simple_completion_item("bar", CompletionItemKind::FIELD, Some("i32".to_string()))], + ) + .await; + } + + #[test] + async fn test_suggests_struct_impl_method() { + let src = r#" + struct Some { + } + + impl Some { + fn foobar(self, x: i32) {} + fn foobar2(&mut self, x: i32) {} + fn foobar3(y: i32) {} + } + + fn foo(some: Some) { + some.f>|< + } + "#; + assert_completion( + src, + vec![ + snippet_completion_item( + "foobar(…)", + CompletionItemKind::FUNCTION, + "foobar(${1:x})", + Some("fn(self, i32)".to_string()), + ), + snippet_completion_item( + "foobar2(…)", + CompletionItemKind::FUNCTION, + "foobar2(${1:x})", + Some("fn(&mut self, i32)".to_string()), + ), + ], + ) + .await; + } + + #[test] + async fn test_suggests_struct_trait_impl_method() { + let src = r#" + struct Some { + } + + trait SomeTrait { + fn foobar(self, x: i32); + fn foobar2(y: i32); + } + + impl SomeTrait for Some { + fn foobar(self, x: i32) {} + fn foobar2(y: i32) {} + } + + fn foo(some: Some) { + some.f>|< + } + "#; + assert_completion( + src, + vec![snippet_completion_item( + "foobar(…)", + CompletionItemKind::FUNCTION, + "foobar(${1:x})", + Some("fn(self, i32)".to_string()), + )], + ) + .await; + } + + #[test] + async fn test_suggests_primitive_trait_impl_method() { + let src = r#" + trait SomeTrait { + fn foobar(self, x: i32); + fn foobar2(y: i32); + } + + impl SomeTrait for Field { + fn foobar(self, x: i32) {} + fn foobar2(y: i32) {} + } + + fn foo(field: Field) { + field.f>|< + } + "#; + assert_completion( + src, + vec![snippet_completion_item( + "foobar(…)", + CompletionItemKind::FUNCTION, + "foobar(${1:x})", + Some("fn(self, i32)".to_string()), + )], + ) + .await; + } + + #[test] + async fn test_suggests_struct_methods_after_colons() { + let src = r#" + struct Some { + } + + impl Some { + fn foobar(self, x: i32) {} + fn foobar2(&mut self, x: i32) {} + fn foobar3(y: i32) {} + } + + fn foo() { + Some::>|< + } + "#; + assert_completion( + src, + vec![ + completion_item_with_sort_text( + snippet_completion_item( + "foobar(…)", + CompletionItemKind::FUNCTION, + "foobar(${1:self}, ${2:x})", + Some("fn(self, i32)".to_string()), + ), + self_mismatch_sort_text(), + ), + completion_item_with_sort_text( + snippet_completion_item( + "foobar2(…)", + CompletionItemKind::FUNCTION, + "foobar2(${1:self}, ${2:x})", + Some("fn(&mut self, i32)".to_string()), + ), + self_mismatch_sort_text(), + ), + snippet_completion_item( + "foobar3(…)", + CompletionItemKind::FUNCTION, + "foobar3(${1:y})", + Some("fn(i32)".to_string()), + ), + ], + ) + .await; + } + + #[test] + async fn test_suggests_struct_behind_alias_methods_after_dot() { + let src = r#" + struct Some { + } + + type Alias = Some; + + impl Some { + fn foobar(self, x: i32) {} + } + + fn foo(some: Alias) { + some.>|< + } + "#; + assert_completion( + src, + vec![snippet_completion_item( + "foobar(…)", + CompletionItemKind::FUNCTION, + "foobar(${1:x})", + Some("fn(self, i32)".to_string()), + )], + ) + .await; + } + + #[test] + async fn test_suggests_struct_behind_alias_methods_after_colons() { + let src = r#" + struct Some { + } + + type Alias = Some; + + impl Some { + fn foobar(self, x: i32) {} + fn foobar2(&mut self, x: i32) {} + fn foobar3(y: i32) {} + } + + fn foo() { + Alias::>|< + } + "#; + assert_completion( + src, + vec![ + completion_item_with_sort_text( + snippet_completion_item( + "foobar(…)", + CompletionItemKind::FUNCTION, + "foobar(${1:self}, ${2:x})", + Some("fn(self, i32)".to_string()), + ), + self_mismatch_sort_text(), + ), + completion_item_with_sort_text( + snippet_completion_item( + "foobar2(…)", + CompletionItemKind::FUNCTION, + "foobar2(${1:self}, ${2:x})", + Some("fn(&mut self, i32)".to_string()), + ), + self_mismatch_sort_text(), + ), + snippet_completion_item( + "foobar3(…)", + CompletionItemKind::FUNCTION, + "foobar3(${1:y})", + Some("fn(i32)".to_string()), + ), + ], + ) + .await; + } } diff --git a/tooling/lsp/src/requests/mod.rs b/tooling/lsp/src/requests/mod.rs index e138f839600..b6c26110e59 100644 --- a/tooling/lsp/src/requests/mod.rs +++ b/tooling/lsp/src/requests/mod.rs @@ -234,7 +234,7 @@ pub(crate) fn on_initialize( )), completion_provider: Some(lsp_types::OneOf::Right(lsp_types::CompletionOptions { resolve_provider: None, - trigger_characters: Some(vec![":".to_string()]), + trigger_characters: Some(vec![".".to_string(), ":".to_string()]), all_commit_characters: None, work_done_progress_options: WorkDoneProgressOptions { work_done_progress: None,