Skip to content
1 change: 0 additions & 1 deletion compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,6 @@ impl<'a> FunctionContext<'a> {
}

self.codegen_intrinsic_call_checks(function, &arguments, call.location);

Ok(self.insert_call(function, arguments, &call.return_type, call.location))
}

Expand Down
20 changes: 12 additions & 8 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ pub(crate) fn check_methods_signatures(
let self_type = resolver.get_self_type().expect("trait impl must have a Self type");

// Temporarily bind the trait's Self type to self_type so we can type check
let _ = the_trait.self_type_typevar.borrow_mut().bind_to(self_type.clone(), the_trait.span);
the_trait.self_type_typevar.bind(self_type.clone());

for (file_id, func_id) in impl_methods {
let impl_method = resolver.interner.function_meta(func_id);
Expand All @@ -433,7 +433,10 @@ pub(crate) fn check_methods_signatures(

let impl_method_generic_count =
impl_method.typ.generic_count() - trait_impl_generic_count;
let trait_method_generic_count = trait_method.generics.len();

// We subtract 1 here to account for the implicit generic `Self` type that is on all
// traits (and thus trait methods) but is not required (or allowed) for users to specify.
let trait_method_generic_count = trait_method.generics().len() - 1;

if impl_method_generic_count != trait_method_generic_count {
let error = DefCollectorErrorKind::MismatchTraitImplementationNumGenerics {
Expand All @@ -447,9 +450,9 @@ pub(crate) fn check_methods_signatures(
}

if let Type::Function(impl_params, _, _) = impl_function_type.0 {
if trait_method.arguments.len() == impl_params.len() {
if trait_method.arguments().len() == impl_params.len() {
// Check the parameters of the impl method against the parameters of the trait method
let args = trait_method.arguments.iter();
let args = trait_method.arguments().iter();
let args_and_params = args.zip(&impl_params).zip(&impl_method.parameters.0);

for (parameter_index, ((expected, actual), (hir_pattern, _, _))) in
Expand All @@ -468,7 +471,7 @@ pub(crate) fn check_methods_signatures(
} else {
let error = DefCollectorErrorKind::MismatchTraitImplementationNumParameters {
actual_num_parameters: impl_method.parameters.0.len(),
expected_num_parameters: trait_method.arguments.len(),
expected_num_parameters: trait_method.arguments().len(),
trait_name: the_trait.name.to_string(),
method_name: func_name.to_string(),
span: impl_method.location.span,
Expand All @@ -481,11 +484,12 @@ pub(crate) fn check_methods_signatures(
let resolved_return_type =
resolver.resolve_type(impl_method.return_type.get_type().into_owned());

trait_method.return_type.unify(&resolved_return_type, &mut typecheck_errors, || {
// TODO: This is not right since it may bind generic return types
trait_method.return_type().unify(&resolved_return_type, &mut typecheck_errors, || {
let ret_type_span = impl_method.return_type.get_type().span;
let expr_span = ret_type_span.expect("return type must always have a span");

let expected_typ = trait_method.return_type.to_string();
let expected_typ = trait_method.return_type().to_string();
let expr_typ = impl_method.return_type().to_string();
TypeCheckError::TypeMismatch { expr_typ, expected_typ, expr_span }
});
Expand All @@ -494,5 +498,5 @@ pub(crate) fn check_methods_signatures(
}
}

the_trait.self_type_typevar.borrow_mut().unbind(the_trait.self_type_typevar_id);
the_trait.self_type_typevar.unbind(the_trait.self_type_typevar_id);
}
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/hir/resolution/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
def_map::{CrateDefMap, ModuleId},
},
node_interner::{FuncId, NodeInterner, TraitImplId},
Shared, Type, TypeBinding,
Type, TypeVariable,
};

use super::{path_resolver::StandardPathResolver, resolver::Resolver};
Expand All @@ -24,7 +24,7 @@ pub(crate) fn resolve_function_set(
mut unresolved_functions: UnresolvedFunctions,
self_type: Option<Type>,
trait_impl_id: Option<TraitImplId>,
impl_generics: Vec<(Rc<String>, Shared<TypeBinding>, Span)>,
impl_generics: Vec<(Rc<String>, TypeVariable, Span)>,
errors: &mut Vec<(CompilationError, FileId)>,
) -> Vec<(FileId, FuncId)> {
let file_id = unresolved_functions.file_id;
Expand Down
9 changes: 3 additions & 6 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ impl<'a> Resolver<'a> {
match length {
None => {
let id = self.interner.next_type_variable_id();
let typevar = Shared::new(TypeBinding::Unbound(id));
let typevar = TypeVariable::unbound(id);
new_variables.push((id, typevar.clone()));

// 'Named'Generic is a bit of a misnomer here, we want a type variable that
Expand Down Expand Up @@ -682,7 +682,7 @@ impl<'a> Resolver<'a> {
vecmap(generics, |generic| {
// Map the generic to a fresh type variable
let id = self.interner.next_type_variable_id();
let typevar = Shared::new(TypeBinding::Unbound(id));
let typevar = TypeVariable::unbound(id);
let span = generic.0.span();

// Check for name collisions of this generic
Expand Down Expand Up @@ -925,10 +925,7 @@ impl<'a> Resolver<'a> {
found.into_iter().collect()
}

fn find_numeric_generics_in_type(
typ: &Type,
found: &mut BTreeMap<String, Shared<TypeBinding>>,
) {
fn find_numeric_generics_in_type(typ: &Type, found: &mut BTreeMap<String, TypeVariable>) {
match typ {
Type::FieldElement
| Type::Integer(_, _)
Expand Down
23 changes: 17 additions & 6 deletions compiler/noirc_frontend/src/hir/resolution/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::{
},
hir_def::traits::{Trait, TraitConstant, TraitFunction, TraitImpl, TraitType},
node_interner::{FuncId, NodeInterner, TraitId},
Path, Shared, TraitItem, Type, TypeVariableKind,
Path, Shared, TraitItem, Type, TypeBinding, TypeVariableKind,
};

use super::{
Expand Down Expand Up @@ -111,8 +111,17 @@ fn resolve_trait_methods(
resolver.set_self_type(Some(self_type));

let arguments = vecmap(parameters, |param| resolver.resolve_type(param.1.clone()));
let resolved_return_type = resolver.resolve_type(return_type.get_type().into_owned());
let generics = resolver.get_generics().to_vec();
let return_type = resolver.resolve_type(return_type.get_type().into_owned());

let mut generics = vecmap(resolver.get_generics(), |(_, type_var, _)| match &*type_var
.borrow()
{
TypeBinding::Unbound(id) => (*id, type_var.clone()),
TypeBinding::Bound(binding) => unreachable!("Trait generic was bound to {binding}"),
});

// Ensure the trait is generic over the Self type as well
generics.push((the_trait.self_type_typevar_id, the_trait.self_type_typevar));

let name = name.clone();
let span: Span = name.span();
Expand All @@ -128,11 +137,13 @@ fn resolve_trait_methods(
None
};

let no_environment = Box::new(Type::Unit);
let function_type = Type::Function(arguments, Box::new(return_type), no_environment);
let typ = Type::Forall(generics, Box::new(function_type));

let f = TraitFunction {
name,
generics,
arguments,
return_type: resolved_return_type,
typ,
span,
default_impl,
default_impl_file_id: unresolved_trait.file_id,
Expand Down
27 changes: 14 additions & 13 deletions compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
types::Type,
},
node_interner::{DefinitionKind, ExprId, FuncId, TraitId, TraitMethodId},
BinaryOpKind, Signedness, TypeBinding, TypeVariableKind, UnaryOp,
BinaryOpKind, Signedness, TypeBinding, TypeBindings, TypeVariableKind, UnaryOp,
};

use super::{errors::TypeCheckError, TypeChecker};
Expand Down Expand Up @@ -289,14 +289,7 @@ impl<'interner> TypeChecker<'interner> {
}
HirExpression::TraitMethodReference(method) => {
let the_trait = self.interner.get_trait(method.trait_id);
let method = &the_trait.methods[method.method_index];

let typ = Type::Function(
method.arguments.clone(),
Box::new(method.return_type.clone()),
Box::new(Type::Unit),
);

let typ = &the_trait.methods[method.method_index].typ;
let (typ, bindings) = typ.instantiate(self.interner);
self.interner.store_instantiation_bindings(*expr_id, bindings);
typ
Expand Down Expand Up @@ -546,7 +539,7 @@ impl<'interner> TypeChecker<'interner> {
HirMethodReference::TraitMethodId(method) => {
let the_trait = self.interner.get_trait(method.trait_id);
let method = &the_trait.methods[method.method_index];
(method.get_type(), method.arguments.len())
(method.typ.clone(), method.arguments().len())
}
};

Expand Down Expand Up @@ -778,7 +771,11 @@ impl<'interner> TypeChecker<'interner> {
}));
}

if other.try_bind_to_polymorphic_int(int).is_ok() || other == &Type::Error {
let mut bindings = TypeBindings::new();
if other.try_bind_to_polymorphic_int(int, &mut bindings).is_ok()
|| other == &Type::Error
{
Type::apply_type_bindings(bindings);
Ok(Bool)
} else {
Err(TypeCheckError::TypeMismatchWithSource {
Expand Down Expand Up @@ -1009,7 +1006,7 @@ impl<'interner> TypeChecker<'interner> {
let env_type = self.interner.next_type_variable();
let expected = Type::Function(args, Box::new(ret.clone()), Box::new(env_type));

if let Err(error) = binding.borrow_mut().bind_to(expected, span) {
if let Err(error) = binding.try_bind(expected, span) {
self.errors.push(error);
}
ret
Expand Down Expand Up @@ -1077,7 +1074,11 @@ impl<'interner> TypeChecker<'interner> {
}));
}

if other.try_bind_to_polymorphic_int(int).is_ok() || other == &Type::Error {
let mut bindings = TypeBindings::new();
if other.try_bind_to_polymorphic_int(int, &mut bindings).is_ok()
|| other == &Type::Error
{
Type::apply_type_bindings(bindings);
Ok(other.clone())
} else {
Err(TypeCheckError::TypeMismatchWithSource {
Expand Down
5 changes: 1 addition & 4 deletions compiler/noirc_frontend/src/hir/type_check/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use crate::hir_def::stmt::{
};
use crate::hir_def::types::Type;
use crate::node_interner::{DefinitionId, ExprId, StmtId};
use crate::{Shared, TypeBinding, TypeVariableKind};

use super::errors::{Source, TypeCheckError};
use super::TypeChecker;
Expand Down Expand Up @@ -71,9 +70,7 @@ impl<'interner> TypeChecker<'interner> {
expr_span: range_span,
});

let fresh_id = self.interner.next_type_variable_id();
let type_variable = Shared::new(TypeBinding::Unbound(fresh_id));
let expected_type = Type::TypeVariable(type_variable, TypeVariableKind::IntegerOrField);
let expected_type = Type::polymorphic_integer(self.interner);

self.unify(&start_range_type, &expected_type, || {
TypeCheckError::TypeCannotBeUsed {
Expand Down
41 changes: 29 additions & 12 deletions compiler/noirc_frontend/src/hir_def/traits.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::rc::Rc;

use crate::{
graph::CrateId,
node_interner::{FuncId, TraitId, TraitMethodId},
Expand All @@ -11,9 +9,7 @@ use noirc_errors::Span;
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TraitFunction {
pub name: Ident,
pub generics: Vec<(Rc<String>, TypeVariable, Span)>,
pub arguments: Vec<Type>,
pub return_type: Type,
pub typ: Type,
pub span: Span,
pub default_impl: Option<Box<NoirFunction>>,
pub default_impl_file_id: fm::FileId,
Expand Down Expand Up @@ -145,12 +141,33 @@ impl std::fmt::Display for Trait {
}

impl TraitFunction {
pub fn get_type(&self) -> Type {
Type::Function(
self.arguments.clone(),
Box::new(self.return_type.clone()),
Box::new(Type::Unit),
)
.generalize()
pub fn arguments(&self) -> &[Type] {
match &self.typ {
Type::Function(args, _, _) => args,
Type::Forall(_, typ) => match typ.as_ref() {
Type::Function(args, _, _) => args,
_ => unreachable!("Trait function does not have a function type"),
},
_ => unreachable!("Trait function does not have a function type"),
}
}

pub fn generics(&self) -> &[(TypeVariableId, TypeVariable)] {
match &self.typ {
Type::Function(..) => &[],
Type::Forall(generics, _) => generics,
_ => unreachable!("Trait function does not have a function type"),
}
}

pub fn return_type(&self) -> &Type {
match &self.typ {
Type::Function(_, return_type, _) => return_type,
Type::Forall(_, typ) => match typ.as_ref() {
Type::Function(_, return_type, _) => return_type,
_ => unreachable!("Trait function does not have a function type"),
},
_ => unreachable!("Trait function does not have a function type"),
}
}
}
Loading