Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/noirc_frontend/src/elaborator/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ impl<'context> Elaborator<'context> {
UnresolvedTypeExpression::Constant(0, span)
});

let length = self.convert_expression_type(length, span);
let length = self.convert_expression_type(length, &Kind::u32(), span);
let (repeated_element, elem_type) = self.elaborate_expression(*repeated_element);

let length_clone = length.clone();
Expand Down
57 changes: 31 additions & 26 deletions compiler/noirc_frontend/src/elaborator/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,22 @@ impl<'context> Elaborator<'context> {
FieldElement => Type::FieldElement,
Array(size, elem) => {
let elem = Box::new(self.resolve_type_inner(*elem, kind));
let size = self.convert_expression_type(size, span);
let size = self.convert_expression_type(size, &Kind::u32(), span);
Type::Array(Box::new(size), elem)
}
Slice(elem) => {
let elem = Box::new(self.resolve_type_inner(*elem, kind));
Type::Slice(elem)
}
Expression(expr) => self.convert_expression_type(expr, span),
Expression(expr) => self.convert_expression_type(expr, kind, span),
Integer(sign, bits) => Type::Integer(sign, bits),
Bool => Type::Bool,
String(size) => {
let resolved_size = self.convert_expression_type(size, span);
let resolved_size = self.convert_expression_type(size, &Kind::u32(), span);
Type::String(Box::new(resolved_size))
}
FormatString(size, fields) => {
let resolved_size = self.convert_expression_type(size, span);
let resolved_size = self.convert_expression_type(size, &Kind::u32(), span);
let fields = self.resolve_type_inner(*fields, kind);
Type::FmtString(Box::new(resolved_size), Box::new(fields))
}
Expand Down Expand Up @@ -426,37 +426,25 @@ impl<'context> Elaborator<'context> {
pub(super) fn convert_expression_type(
&mut self,
length: UnresolvedTypeExpression,
expected_kind: &Kind,
span: Span,
) -> Type {
match length {
UnresolvedTypeExpression::Variable(path) => {
let resolved_length =
self.lookup_generic_or_global_type(&path).unwrap_or_else(|| {
self.push_err(ResolverError::NoSuchNumericTypeVariable { path });
Type::Constant(0, Kind::u32())
});

if let Type::NamedGeneric(ref _type_var, ref _name, ref kind) = resolved_length {
if !kind.is_numeric() {
self.push_err(TypeCheckError::TypeKindMismatch {
expected_kind: Kind::u32().to_string(),
expr_kind: kind.to_string(),
expr_span: span,
});
return Type::Error;
}
}
resolved_length
let typ = self.resolve_named_type(path, GenericTypeArgs::default());
self.check_kind(typ, expected_kind, span)
}
UnresolvedTypeExpression::Constant(int, _span) => {
Type::Constant(int, expected_kind.clone())
}
UnresolvedTypeExpression::Constant(int, _span) => Type::Constant(int, Kind::u32()),
UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, span) => {
let (lhs_span, rhs_span) = (lhs.span(), rhs.span());
let lhs = self.convert_expression_type(*lhs, lhs_span);
let rhs = self.convert_expression_type(*rhs, rhs_span);
let lhs = self.convert_expression_type(*lhs, expected_kind, lhs_span);
let rhs = self.convert_expression_type(*rhs, expected_kind, rhs_span);

match (lhs, rhs) {
(Type::Constant(lhs, lhs_kind), Type::Constant(rhs, rhs_kind)) => {
if lhs_kind != rhs_kind {
if !lhs_kind.unifies(&rhs_kind) {
self.push_err(TypeCheckError::TypeKindMismatch {
expected_kind: lhs_kind.to_string(),
expr_kind: rhs_kind.to_string(),
Expand All @@ -474,10 +462,27 @@ impl<'context> Elaborator<'context> {
(lhs, rhs) => Type::InfixExpr(Box::new(lhs), op, Box::new(rhs)).canonicalize(),
}
}
UnresolvedTypeExpression::AsTraitPath(path) => self.resolve_as_trait_path(*path),
UnresolvedTypeExpression::AsTraitPath(path) => {
let typ = self.resolve_as_trait_path(*path);
self.check_kind(typ, expected_kind, span)
}
}
}

fn check_kind(&mut self, typ: Type, expected_kind: &Kind, span: Span) -> Type {
if let Some(kind) = typ.kind() {
if !kind.unifies(expected_kind) {
self.push_err(TypeCheckError::TypeKindMismatch {
expected_kind: expected_kind.to_string(),
expr_kind: kind.to_string(),
expr_span: span,
});
return Type::Error;
}
}
typ
}

fn resolve_as_trait_path(&mut self, path: AsTraitPath) -> Type {
let span = path.trait_path.span;
let Some(trait_id) = self.resolve_trait_by_path(path.trait_path.clone()) else {
Expand Down
30 changes: 27 additions & 3 deletions compiler/noirc_frontend/src/hir_def/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,36 @@ impl Kind {
}

pub(crate) fn matches_opt(&self, other: Option<Self>) -> bool {
other.as_ref().map_or(true, |other_kind| self == other_kind)
other.as_ref().map_or(true, |other_kind| self.unifies(other_kind))
}

pub(crate) fn u32() -> Self {
Self::Numeric(Box::new(Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo)))
}

/// Unifies this kind with the other. Returns true on success
pub(crate) fn unifies(&self, other: &Kind) -> bool {
match (self, other) {
(Kind::Normal, Kind::Normal) => true,
(Kind::Numeric(lhs), Kind::Numeric(rhs)) => {
let mut bindings = TypeBindings::new();
let unifies = lhs.try_unify(rhs, &mut bindings).is_ok();
if unifies {
Type::apply_type_bindings(bindings);
}
unifies
}
_ => false,
}
}

pub(crate) fn unify(&self, other: &Kind) -> Result<(), UnificationError> {
if self.unifies(other) {
Ok(())
} else {
Err(UnificationError)
}
}
}

impl std::fmt::Display for Kind {
Expand Down Expand Up @@ -1465,13 +1489,13 @@ impl Type {
}
}

(NamedGeneric(binding_a, name_a, _), NamedGeneric(binding_b, name_b, _)) => {
(NamedGeneric(binding_a, name_a, kind_a), NamedGeneric(binding_b, name_b, kind_b)) => {
// Bound NamedGenerics are caught by the check above
assert!(binding_a.borrow().is_unbound());
assert!(binding_b.borrow().is_unbound());

if name_a == name_b {
Ok(())
kind_a.unify(kind_b)
} else {
Err(UnificationError)
}
Expand Down
68 changes: 62 additions & 6 deletions compiler/noirc_frontend/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1616,25 +1616,30 @@ fn numeric_generic_binary_operation_type_mismatch() {
#[test]
fn bool_generic_as_loop_bound() {
let src = r#"
pub fn read<let N: bool>() {
let mut fields = [0; N];
for i in 0..N {
pub fn read<let N: bool>() { // error here
let mut fields = [0; N]; // error here
for i in 0..N { // error here
fields[i] = i + 1;
}
assert(fields[0] == 1);
}
"#;
let errors = get_program_errors(src);
assert_eq!(errors.len(), 2);
assert_eq!(errors.len(), 3);

assert!(matches!(
errors[0].0,
CompilationError::ResolverError(ResolverError::UnsupportedNumericGenericType { .. }),
));

assert!(matches!(
errors[1].0,
CompilationError::TypeError(TypeCheckError::TypeKindMismatch { .. }),
));

let CompilationError::TypeError(TypeCheckError::TypeMismatch {
expected_typ, expr_typ, ..
}) = &errors[1].0
}) = &errors[2].0
else {
panic!("Got an error other than a type mismatch");
};
Expand All @@ -1646,7 +1651,7 @@ fn bool_generic_as_loop_bound() {
#[test]
fn numeric_generic_in_function_signature() {
let src = r#"
pub fn foo<let N: u8>(arr: [Field; N]) -> [Field; N] { arr }
pub fn foo<let N: u32>(arr: [Field; N]) -> [Field; N] { arr }
"#;
assert_no_errors(src);
}
Expand Down Expand Up @@ -3644,3 +3649,54 @@ fn does_not_crash_when_passing_mutable_undefined_variable() {

assert_eq!(name, "undefined");
}

#[test]
fn infer_globals_to_u32_from_type_use() {
let src = r#"
global ARRAY_LEN = 3;
global STR_LEN = 2;
global FMT_STR_LEN = 2;

fn main() {
let _a: [u32; ARRAY_LEN] = [1, 2, 3];
let _b: str<STR_LEN> = "hi";
let _c: fmtstr<FMT_STR_LEN, _> = f"hi";
}
"#;

let errors = get_program_errors(src);
assert_eq!(errors.len(), 0);
}

#[test]
fn non_u32_in_array_length() {
let src = r#"
global ARRAY_LEN: u8 = 3;

fn main() {
let _a: [u32; ARRAY_LEN] = [1, 2, 3];
}
"#;

let errors = get_program_errors(src);
assert_eq!(errors.len(), 1);

assert!(matches!(
errors[0].0,
CompilationError::TypeError(TypeCheckError::TypeKindMismatch { .. })
));
}

#[test]
fn use_non_u32_generic_in_struct() {
let src = r#"
struct S<let N: u8> {}

fn main() {
let _: S<3> = S {};
}
"#;

let errors = get_program_errors(src);
assert_eq!(errors.len(), 0);
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fn main() {
}

// Used in the signature of a function
fn id<let I: Field>(x: [Field; I]) -> [Field; I] {
fn id<let I: u32>(x: [Field; I]) -> [Field; I] {
x
}

Expand Down