diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs index 41327c988d2..d7e6b8b0a3d 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs @@ -591,7 +591,6 @@ impl<'a> FunctionContext<'a> { } self.codegen_intrinsic_call_checks(function, &arguments, call.location); - Ok(self.insert_call(function, arguments, &call.return_type, call.location)) } diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index d6eddeffc07..0806a8eb757 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -433,7 +433,10 @@ pub(crate) fn check_methods_signatures( let impl_method_generic_count = impl_method.typ.generic_count() - trait_impl_generic_count; - let trait_method_generic_count = trait_method.generics.len(); + + // We subtract 1 here to account for the implicit generic `Self` type that is on all + // traits (and thus trait methods) but is not required (or allowed) for users to specify. + let trait_method_generic_count = trait_method.generics().len() - 1; if impl_method_generic_count != trait_method_generic_count { let error = DefCollectorErrorKind::MismatchTraitImplementationNumGenerics { @@ -447,9 +450,9 @@ pub(crate) fn check_methods_signatures( } if let Type::Function(impl_params, _, _) = impl_function_type.0 { - if trait_method.arguments.len() == impl_params.len() { + if trait_method.arguments().len() == impl_params.len() { // Check the parameters of the impl method against the parameters of the trait method - let args = trait_method.arguments.iter(); + let args = trait_method.arguments().iter(); let args_and_params = args.zip(&impl_params).zip(&impl_method.parameters.0); for (parameter_index, ((expected, actual), (hir_pattern, _, _))) in @@ -468,7 +471,7 @@ pub(crate) fn check_methods_signatures( } else { let error = DefCollectorErrorKind::MismatchTraitImplementationNumParameters { actual_num_parameters: impl_method.parameters.0.len(), - expected_num_parameters: trait_method.arguments.len(), + expected_num_parameters: trait_method.arguments().len(), trait_name: the_trait.name.to_string(), method_name: func_name.to_string(), span: impl_method.location.span, @@ -481,11 +484,12 @@ pub(crate) fn check_methods_signatures( let resolved_return_type = resolver.resolve_type(impl_method.return_type.get_type().into_owned()); - trait_method.return_type.unify(&resolved_return_type, &mut typecheck_errors, || { + // TODO: This is not right since it may bind generic return types + trait_method.return_type().unify(&resolved_return_type, &mut typecheck_errors, || { let ret_type_span = impl_method.return_type.get_type().span; let expr_span = ret_type_span.expect("return type must always have a span"); - let expected_typ = trait_method.return_type.to_string(); + let expected_typ = trait_method.return_type().to_string(); let expr_typ = impl_method.return_type().to_string(); TypeCheckError::TypeMismatch { expr_typ, expected_typ, expr_span } }); diff --git a/compiler/noirc_frontend/src/hir/resolution/traits.rs b/compiler/noirc_frontend/src/hir/resolution/traits.rs index 702e96362a6..7a6cbccb081 100644 --- a/compiler/noirc_frontend/src/hir/resolution/traits.rs +++ b/compiler/noirc_frontend/src/hir/resolution/traits.rs @@ -18,7 +18,7 @@ use crate::{ }, hir_def::traits::{Trait, TraitConstant, TraitFunction, TraitImpl, TraitType}, node_interner::{FuncId, NodeInterner, TraitId}, - Path, Shared, TraitItem, Type, TypeVariableKind, + Path, Shared, TraitItem, Type, TypeBinding, TypeVariableKind, }; use super::{ @@ -111,8 +111,17 @@ fn resolve_trait_methods( resolver.set_self_type(Some(self_type)); let arguments = vecmap(parameters, |param| resolver.resolve_type(param.1.clone())); - let resolved_return_type = resolver.resolve_type(return_type.get_type().into_owned()); - let generics = resolver.get_generics().to_vec(); + let return_type = resolver.resolve_type(return_type.get_type().into_owned()); + + let mut generics = vecmap(resolver.get_generics(), |(_, type_var, _)| match &*type_var + .borrow() + { + TypeBinding::Unbound(id) => (*id, type_var.clone()), + TypeBinding::Bound(binding) => unreachable!("Trait generic was bound to {binding}"), + }); + + // Ensure the trait is generic over the Self type as well + generics.push((the_trait.self_type_typevar_id, the_trait.self_type_typevar)); let name = name.clone(); let span: Span = name.span(); @@ -128,11 +137,13 @@ fn resolve_trait_methods( None }; + let no_environment = Box::new(Type::Unit); + let function_type = Type::Function(arguments, Box::new(return_type), no_environment); + let typ = Type::Forall(generics, Box::new(function_type)); + let f = TraitFunction { name, - generics, - arguments, - return_type: resolved_return_type, + typ, span, default_impl, default_impl_file_id: unresolved_trait.file_id, diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index 1fe1eaa899c..9a64ab55196 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -289,14 +289,7 @@ impl<'interner> TypeChecker<'interner> { } HirExpression::TraitMethodReference(method) => { let the_trait = self.interner.get_trait(method.trait_id); - let method = &the_trait.methods[method.method_index]; - - let typ = Type::Function( - method.arguments.clone(), - Box::new(method.return_type.clone()), - Box::new(Type::Unit), - ); - + let typ = &the_trait.methods[method.method_index].typ; let (typ, bindings) = typ.instantiate(self.interner); self.interner.store_instantiation_bindings(*expr_id, bindings); typ @@ -546,7 +539,7 @@ impl<'interner> TypeChecker<'interner> { HirMethodReference::TraitMethodId(method) => { let the_trait = self.interner.get_trait(method.trait_id); let method = &the_trait.methods[method.method_index]; - (method.get_type(), method.arguments.len()) + (method.typ.clone(), method.arguments().len()) } }; diff --git a/compiler/noirc_frontend/src/hir_def/traits.rs b/compiler/noirc_frontend/src/hir_def/traits.rs index 5f0bf49ca0f..e6c46a46073 100644 --- a/compiler/noirc_frontend/src/hir_def/traits.rs +++ b/compiler/noirc_frontend/src/hir_def/traits.rs @@ -1,5 +1,3 @@ -use std::rc::Rc; - use crate::{ graph::CrateId, node_interner::{FuncId, TraitId, TraitMethodId}, @@ -11,9 +9,7 @@ use noirc_errors::Span; #[derive(Clone, Debug, PartialEq, Eq)] pub struct TraitFunction { pub name: Ident, - pub generics: Vec<(Rc, TypeVariable, Span)>, - pub arguments: Vec, - pub return_type: Type, + pub typ: Type, pub span: Span, pub default_impl: Option>, pub default_impl_file_id: fm::FileId, @@ -145,12 +141,33 @@ impl std::fmt::Display for Trait { } impl TraitFunction { - pub fn get_type(&self) -> Type { - Type::Function( - self.arguments.clone(), - Box::new(self.return_type.clone()), - Box::new(Type::Unit), - ) - .generalize() + pub fn arguments(&self) -> &[Type] { + match &self.typ { + Type::Function(args, _, _) => args, + Type::Forall(_, typ) => match typ.as_ref() { + Type::Function(args, _, _) => args, + _ => unreachable!("Trait function does not have a function type"), + }, + _ => unreachable!("Trait function does not have a function type"), + } + } + + pub fn generics(&self) -> &[(TypeVariableId, TypeVariable)] { + match &self.typ { + Type::Function(..) => &[], + Type::Forall(generics, _) => generics, + _ => unreachable!("Trait function does not have a function type"), + } + } + + pub fn return_type(&self) -> &Type { + match &self.typ { + Type::Function(_, return_type, _) => return_type, + Type::Forall(_, typ) => match typ.as_ref() { + Type::Function(_, return_type, _) => return_type, + _ => unreachable!("Trait function does not have a function type"), + }, + _ => unreachable!("Trait function does not have a function type"), + } } } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index c7b3fcc499c..ff5e157cec4 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -75,7 +75,7 @@ pub enum Type { /// the environment should be `Unit` by default, /// for closures it should contain a `Tuple` type with the captured /// variable types. - Function(Vec, Box, Box), + Function(Vec, /*return_type:*/ Box, /*environment:*/ Box), /// &mut T MutableReference(Box), @@ -668,31 +668,12 @@ impl Type { } } - /// Takes a monomorphic type and generalizes it over each of the given type variables. - pub(crate) fn generalize_from_variables( - self, - type_vars: HashMap, - ) -> Type { - let polymorphic_type_vars = vecmap(type_vars, |type_var| type_var); - Type::Forall(polymorphic_type_vars, Box::new(self)) - } - /// Takes a monomorphic type and generalizes it over each of the type variables in the /// given type bindings, ignoring what each type variable is bound to in the TypeBindings. pub(crate) fn generalize_from_substitutions(self, type_bindings: TypeBindings) -> Type { let polymorphic_type_vars = vecmap(type_bindings, |(id, (type_var, _))| (id, type_var)); Type::Forall(polymorphic_type_vars, Box::new(self)) } - - /// Takes a monomorphic type and generalizes it over each type variable found within. - /// - /// Note that Noir's type system assumes any Type::Forall are only present at top-level, - /// and thus all type variable's within a type are free. - pub(crate) fn generalize(self) -> Type { - let mut type_variables = HashMap::new(); - self.find_all_unbound_type_variables(&mut type_variables); - self.generalize_from_variables(type_variables) - } } impl std::fmt::Display for Type { diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index e72c3a2a948..52ed0c746e1 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -886,7 +886,6 @@ impl<'interner> Monomorphizer<'interner> { let original_func = Box::new(self.expr(call.func)); let mut arguments = vecmap(&call.arguments, |id| self.expr(*id)); let hir_arguments = vecmap(&call.arguments, |id| self.interner.expression(id)); - let func: Box; let return_type = self.interner.id_type(id); let return_type = self.convert_type(&return_type); @@ -907,7 +906,8 @@ impl<'interner> Monomorphizer<'interner> { let func_type = self.interner.id_type(call.func); let func_type = self.convert_type(&func_type); let is_closure = self.is_function_closure(func_type); - if is_closure { + + let func = if is_closure { let local_id = self.next_local_id(); // store the function in a temporary variable before calling it @@ -929,14 +929,13 @@ impl<'interner> Monomorphizer<'interner> { typ: self.convert_type(&self.interner.id_type(call.func)), }); - func = Box::new(ast::Expression::ExtractTupleField( - Box::new(extracted_func.clone()), - 1usize, - )); - let env_argument = ast::Expression::ExtractTupleField(Box::new(extracted_func), 0usize); + let env_argument = + ast::Expression::ExtractTupleField(Box::new(extracted_func.clone()), 0usize); arguments.insert(0, env_argument); + + Box::new(ast::Expression::ExtractTupleField(Box::new(extracted_func), 1usize)) } else { - func = original_func.clone(); + original_func.clone() }; let call = self