From b0cc5ad69ae24f41f23f6de42c4528c976a8dc08 Mon Sep 17 00:00:00 2001 From: "Michael J. Klein" Date: Wed, 1 May 2024 20:58:58 -0400 Subject: [PATCH 1/3] wip: add PartialEq for debugging, add error for disallowed generic arith operation, move constant/named-generic into sub-struct, add unit and test_programs tests --- .../noirc_frontend/src/hir/comptime/errors.rs | 3 +- .../src/hir/def_collector/dc_crate.rs | 3 +- .../src/hir/def_collector/errors.rs | 6 +- .../src/hir/resolution/errors.rs | 7 + .../src/hir/resolution/resolver.rs | 105 +++++++++--- compiler/noirc_frontend/src/hir_def/types.rs | 158 +++++++++++++----- compiler/noirc_frontend/src/tests.rs | 10 ++ .../numeric_generic_arith/Nargo.toml | 7 + .../numeric_generic_arith/src/main.nr | 2 + 9 files changed, 228 insertions(+), 73 deletions(-) create mode 100644 test_programs/compile_success_empty/numeric_generic_arith/Nargo.toml create mode 100644 test_programs/compile_success_empty/numeric_generic_arith/src/main.nr diff --git a/compiler/noirc_frontend/src/hir/comptime/errors.rs b/compiler/noirc_frontend/src/hir/comptime/errors.rs index af5ba9a44cf..1f9a7e6f6e3 100644 --- a/compiler/noirc_frontend/src/hir/comptime/errors.rs +++ b/compiler/noirc_frontend/src/hir/comptime/errors.rs @@ -5,7 +5,8 @@ use noirc_errors::{CustomDiagnostic, Location}; use super::value::Value; /// The possible errors that can halt the interpreter. -#[derive(Debug, Clone)] +// TODO: revert PartialEq +#[derive(Debug, Clone, PartialEq)] pub enum InterpreterError { ArgumentCountMismatch { expected: usize, actual: usize, location: Location }, TypeMismatch { expected: Type, value: Value, location: Location }, diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 2f6b101e62f..a3920a318ad 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -149,7 +149,8 @@ pub struct DefCollector { pub(crate) type ImplMap = HashMap<(UnresolvedType, LocalModuleId), Vec<(UnresolvedGenerics, Span, UnresolvedFunctions)>>; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] +// TODO: revert PartialEq pub enum CompilationError { ParseError(ParserError), DefinitionError(DefCollectorErrorKind), diff --git a/compiler/noirc_frontend/src/hir/def_collector/errors.rs b/compiler/noirc_frontend/src/hir/def_collector/errors.rs index edeb463e10d..b8fe12e990e 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/errors.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/errors.rs @@ -22,7 +22,8 @@ pub enum DuplicateType { TraitAssociatedFunction, } -#[derive(Error, Debug, Clone)] +// TODO: revert PartialEq +#[derive(Error, Debug, Clone, PartialEq)] pub enum DefCollectorErrorKind { #[error("duplicate {typ} found in namespace")] Duplicate { typ: DuplicateType, first_def: Ident, second_def: Ident }, @@ -69,7 +70,8 @@ pub enum DefCollectorErrorKind { } /// An error struct that macro processors can return. -#[derive(Debug, Clone)] +// TODO: revert PartialEq +#[derive(Debug, Clone, PartialEq)] pub struct MacroError { pub primary_message: String, pub secondary_message: Option, diff --git a/compiler/noirc_frontend/src/hir/resolution/errors.rs b/compiler/noirc_frontend/src/hir/resolution/errors.rs index 1727471c34f..16a7b0c6b4c 100644 --- a/compiler/noirc_frontend/src/hir/resolution/errors.rs +++ b/compiler/noirc_frontend/src/hir/resolution/errors.rs @@ -46,6 +46,8 @@ pub enum ResolverError { MissingRhsExpr { name: String, span: Span }, #[error("Expression invalid in an array length context")] InvalidArrayLengthExpr { span: Span }, + #[error("Expression invalid for generic arithmetic: only '+', '-', and '*' are allowed.")] + InvalidGenericArithOp { span: Span }, #[error("Integer too large to be evaluated in an array length context")] IntegerTooLarge { span: Span }, #[error("No global or generic type parameter found with the given name")] @@ -236,6 +238,11 @@ impl<'a> From<&'a ResolverError> for Diagnostic { "Array-length expressions can only have simple integer operations and any variables used must be global constants".into(), *span, ), + ResolverError::InvalidGenericArithOp { span } => Diagnostic::simple_error( + "Expression invalid for generic arithmetic with variables: only '+', '-', and '*' are allowed.".into(), + "Generic expressions can only have constants, variables, and simple integer operations: '+', '-', and '*'".into(), + *span, + ), ResolverError::IntegerTooLarge { span } => Diagnostic::simple_error( "Integer too large to be evaluated to an array-length".into(), "Array-lengths may be a maximum size of usize::MAX, including intermediate calculations".into(), diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index bef0ebdaacc..062a4d840f7 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -40,7 +40,7 @@ use crate::node_interner::{ DefinitionId, DefinitionKind, DependencyId, ExprId, FuncId, GlobalId, NodeInterner, StmtId, StructId, TraitId, TraitImplId, TraitMethodId, TypeAliasId, }; -use crate::{Generics, Shared, StructType, Type, TypeAlias, TypeVariable, TypeVariableKind}; +use crate::{BinaryTypeOperator, Generics, GenericArith, GenericArithOpKind, Shared, StructType, Type, TypeAlias, TypeVariable, TypeVariableKind}; use fm::FileId; use iter_extended::vecmap; use noirc_errors::{Location, Span, Spanned}; @@ -586,7 +586,7 @@ impl<'a> Resolver<'a> { let env = Box::new(self.resolve_type_inner(*env, new_variables)); match *env { - Type::Unit | Type::Tuple(_) | Type::NamedGeneric(_, _) => { + Type::Unit | Type::Tuple(_) | Type::GenericArith(GenericArith::NamedGeneric(..), _) => { Type::Function(args, ret, env) } _ => { @@ -753,7 +753,7 @@ impl<'a> Resolver<'a> { if path.segments.len() == 1 { let name = &path.last_segment().0.contents; if let Some((name, var, _)) = self.find_generic(name) { - return Some(Type::NamedGeneric(var.clone(), name.clone())); + return Some(Type::GenericArith(GenericArith::NamedGeneric(var.clone(), name.clone()), Shared::new(vec![]))); } } @@ -767,7 +767,7 @@ impl<'a> Resolver<'a> { if let Some(error) = error { self.push_err(error.into()); } - Some(Type::Constant(self.eval_global_as_array_length(id, path))) + Some(Type::GenericArith(GenericArith::Constant(self.eval_global_as_array_length(id, path)))) } _ => None, } @@ -787,7 +787,7 @@ impl<'a> Resolver<'a> { // 'Named'Generic is a bit of a misnomer here, we want a type variable that // wont be bound over but this one has no name since we do not currently // require users to explicitly be generic over array lengths. - Type::NamedGeneric(typevar, Rc::new("".into())) + Type::GenericArith(GenericArith::NamedGeneric(typevar, Rc::new("".into()))) } Some(length) => self.convert_expression_type(length), } @@ -798,25 +798,70 @@ impl<'a> Resolver<'a> { UnresolvedTypeExpression::Variable(path) => { self.lookup_generic_or_global_type(&path).unwrap_or_else(|| { self.push_err(ResolverError::NoSuchNumericTypeVariable { path }); - Type::Constant(0) + Type::GenericArith(GenericArith::Constant(0)) }) } - UnresolvedTypeExpression::Constant(int, _) => Type::Constant(int), - UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, _) => { + UnresolvedTypeExpression::Constant(int, _) => Type::GenericArith(GenericArith::Constant(int)), + UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, op_span) => { let (lhs_span, rhs_span) = (lhs.span(), rhs.span()); let lhs = self.convert_expression_type(*lhs); let rhs = self.convert_expression_type(*rhs); match (lhs, rhs) { - (Type::Constant(lhs), Type::Constant(rhs)) => { - Type::Constant(op.function()(lhs, rhs)) + // (Type::Constant(lhs), Type::Constant(rhs)) => { + // Type::Constant(op.function()(lhs, rhs)) + // } + + (Type::GenericArith(rhs), Type::GenericArith(lhs)) => { + match (lhs, rhs) { + (GenericArith::Constant(lhs), GenericArith::Constant(rhs)) => { + Type::GenericArith(GenericArith::Constant(op.function()(lhs, rhs))) + } + (lhs, rhs) => { + match op { + BinaryTypeOperator::Addition => { + Type::GenericArith(GenericArith::Op { + kind: GenericArithOpKind::Add, + lhs: Box::new(lhs), + rhs: Box::new(rhs), + }) + } + BinaryTypeOperator::Multiplication => { + Type::GenericArith(GenericArith::Op { + kind: GenericArithOpKind::Mul, + lhs: Box::new(lhs), + rhs: Box::new(rhs), + }) + } + BinaryTypeOperator::Subtraction => { + Type::GenericArith(GenericArith::Op { + kind: GenericArithOpKind::Sub, + lhs: Box::new(lhs), + rhs: Box::new(rhs), + }) + } + _ => { + self.push_err(ResolverError::InvalidGenericArithOp { span: op_span }); + Type::GenericArith(GenericArith::Constant(0)) + } + } + } + } + + } - (lhs, _) => { - let span = - if !matches!(lhs, Type::Constant(_)) { lhs_span } else { rhs_span }; - self.push_err(ResolverError::InvalidArrayLengthExpr { span }); - Type::Constant(0) + + (lhs, rhs) => { + + // TODO cleanup + // let span = + // if !matches!(lhs, Type::Constant(_)) { lhs_span } else { rhs_span }; + // // TODO: revert before PR + // dbg!("convert_expression_type", &lhs, &rhs); + self.push_err(ResolverError::InvalidArrayLengthExpr { span: op_span }); + Type::GenericArith(GenericArith::Constant(0)) } + } } } @@ -1195,8 +1240,9 @@ impl<'a> Resolver<'a> { | Type::Unit | Type::Error | Type::TypeVariable(_, _) - | Type::Constant(_) - | Type::NamedGeneric(_, _) + // TODO + // | Type::Constant(_) + // | Type::NamedGeneric(_, _) | Type::Code | Type::Forall(_, _) => (), @@ -1207,7 +1253,7 @@ impl<'a> Resolver<'a> { } Type::Array(length, element_type) => { - if let Type::NamedGeneric(type_variable, name) = length.as_ref() { + if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, name)) = length.as_ref() { found.insert(name.to_string(), type_variable.clone()); } Self::find_numeric_generics_in_type(element_type, found); @@ -1232,7 +1278,7 @@ impl<'a> Resolver<'a> { Type::Struct(struct_type, generics) => { for (i, generic) in generics.iter().enumerate() { - if let Type::NamedGeneric(type_variable, name) = generic { + if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, name)) = generic { if struct_type.borrow().generic_is_numeric(i) { found.insert(name.to_string(), type_variable.clone()); } @@ -1243,7 +1289,7 @@ impl<'a> Resolver<'a> { } Type::Alias(alias, generics) => { for (i, generic) in generics.iter().enumerate() { - if let Type::NamedGeneric(type_variable, name) = generic { + if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, name)) = generic { if alias.borrow().generic_is_numeric(i) { found.insert(name.to_string(), type_variable.clone()); } @@ -1254,12 +1300,12 @@ impl<'a> Resolver<'a> { } Type::MutableReference(element) => Self::find_numeric_generics_in_type(element, found), Type::String(length) => { - if let Type::NamedGeneric(type_variable, name) = length.as_ref() { + if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, name)) = length.as_ref() { found.insert(name.to_string(), type_variable.clone()); } } Type::FmtString(length, fields) => { - if let Type::NamedGeneric(type_variable, name) = length.as_ref() { + if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, name)) = length.as_ref() { found.insert(name.to_string(), type_variable.clone()); } Self::find_numeric_generics_in_type(fields, found); @@ -2109,10 +2155,16 @@ impl<'a> Resolver<'a> { let expression = let_statement.expression; self.try_eval_array_length_id_with_fuel(expression, span, fuel - 1) } else { + // TODO: revert before PR + dbg!("Global", definition); Err(Some(ResolverError::InvalidArrayLengthExpr { span })) } } - _ => Err(Some(ResolverError::InvalidArrayLengthExpr { span })), + _ => { + // TODO: revert before PR + dbg!("Ident", definition); + Err(Some(ResolverError::InvalidArrayLengthExpr { span })) + } } } HirExpression::Infix(infix) => { @@ -2138,7 +2190,12 @@ impl<'a> Resolver<'a> { BinaryOpKind::Modulo => Ok(lhs % rhs), } } - _other => Err(Some(ResolverError::InvalidArrayLengthExpr { span })), + _other => { + // TODO: revert before PR + dbg!("_other", _other); + + Err(Some(ResolverError::InvalidArrayLengthExpr { span })) + } } } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 637f3c99e89..8dc8795f6fe 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -27,7 +27,7 @@ pub enum Type { FieldElement, /// Array(N, E) is an array of N elements of type E. It is expected that N - /// is either a type variable of some kind or a Type::Constant. + /// is either a type variable of some kind or a Type::GenericArith. Array(Box, Box), /// Slice(E) is a slice of elements of type E. @@ -41,7 +41,7 @@ pub enum Type { Bool, /// String(N) is an array of characters of length N. It is expected that N - /// is either a type variable of some kind or a Type::Constant. + /// is either a type variable of some kind or a Type::GenericArith(GenericArith::Constant). String(Box), /// FmtString(N, Vec) is an array of characters of length N that contains @@ -80,9 +80,10 @@ pub enum Type { /// used for displaying error messages using the name of the trait. TraitAsType(TraitId, /*name:*/ Rc, /*generics:*/ Vec), - /// NamedGenerics are the 'T' or 'U' in a user-defined generic function - /// like `fn foo(...) {}`. Unlike TypeVariables, they cannot be bound over. - NamedGeneric(TypeVariable, Rc), + // TODO: cleanup before PR + // /// NamedGenerics are the 'T' or 'U' in a user-defined generic function + // /// like `fn foo(...) {}`. Unlike TypeVariables, they cannot be bound over. + // NamedGeneric(TypeVariable, Rc), /// A functions with arguments, a return type and environment. /// the environment should be `Unit` by default, @@ -100,9 +101,11 @@ pub enum Type { /// will be and thus needs the full TypeVariable link. Forall(Generics, Box), - /// A type-level integer. Included to let an Array's size type variable - /// bind to an integer without special checks to bind it to a non-type. - Constant(u64), + // /// A type-level integer. Included to let an Array's size type variable + // /// bind to an integer without special checks to bind it to a non-type. + // Constant(u64), + + GenericArith(GenericArith, /*equiv_generics*/ Shared>), /// The type of quoted code in macros. This is always a comptime-only type Code, @@ -114,6 +117,53 @@ pub enum Type { Error, } +// TODO: relocate +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub enum GenericArith { + Op { + kind: GenericArithOpKind, + lhs: Box, + rhs: Box, + }, + NamedGeneric(TypeVariable, Rc), + Constant(u64), +} + +// TODO: relocate +impl std::fmt::Display for GenericArith { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + GenericArith::Op { kind, lhs, rhs } => write!(f, "{lhs} {kind} {rhs}"), + + GenericArith::NamedGeneric(binding, name) => match &*binding.borrow() { + TypeBinding::Bound(binding) => binding.fmt(f), + TypeBinding::Unbound(_) if name.is_empty() => write!(f, "_"), + TypeBinding::Unbound(_) => write!(f, "{name}"), + }, + GenericArith::Constant(x) => x.fmt(f), + } + } +} + +// TODO: relocate +#[derive(PartialEq, Eq, Clone, Hash)] +pub enum GenericArithOpKind { + Mul, + Add, + Sub, +} + +// TODO: relocate +impl std::fmt::Display for GenericArithOpKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + GenericArithOpKind::Mul => write!(f, "*"), + GenericArithOpKind::Add => write!(f, "+"), + GenericArithOpKind::Sub => write!(f, "-"), + } + } +} + impl Type { /// Returns the number of field elements required to represent the type once encoded. pub fn field_count(&self) -> u32 { @@ -145,11 +195,10 @@ impl Type { | Type::Unit | Type::TypeVariable(_, _) | Type::TraitAsType(..) - | Type::NamedGeneric(_, _) | Type::Function(_, _, _) | Type::MutableReference(_) | Type::Forall(_, _) - | Type::Constant(_) + | Type::GenericArith(_) | Type::Code | Type::Slice(_) | Type::Error => unreachable!("This type cannot exist as a parameter to main"), @@ -454,9 +503,11 @@ pub enum TypeVariableKind { /// that can only be bound to Type::Integer, or other polymorphic integers. Integer, + // TODO: update doc /// A potentially constant array size. This will only bind to itself or - /// Type::Constant(n) with a matching size. This defaults to Type::Constant(n) if still unbound - /// during monomorphization. + /// Type::GenericArith(GenericArith::Constant(n)) with a matching size. + /// This defaults to Type::GenericArith(GenericArith::Constant(n)) if still + /// unbound during monomorphization. Constant(u64), } @@ -621,7 +672,7 @@ impl Type { fn contains_numeric_typevar(&self, target_id: TypeVariableId) -> bool { // True if the given type is a NamedGeneric with the target_id let named_generic_id_matches_target = |typ: &Type| { - if let Type::NamedGeneric(type_variable, _) = typ { + if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, _)) = typ { match &*type_variable.borrow() { TypeBinding::Bound(_) => { unreachable!("Named generics should not be bound until monomorphization") @@ -640,8 +691,7 @@ impl Type { | Type::Unit | Type::Error | Type::TypeVariable(_, _) - | Type::Constant(_) - | Type::NamedGeneric(_, _) + | Type::GenericArith(_) | Type::Forall(_, _) | Type::Code => false, @@ -700,12 +750,12 @@ impl Type { | Type::Integer(_, _) | Type::Bool | Type::Unit - | Type::Constant(_) + | Type::GenericArith(GenericArith::Constant(_)) | Type::Error => true, Type::FmtString(_, _) | Type::TypeVariable(_, _) - | Type::NamedGeneric(_, _) + | Type::GenericArith(_) | Type::Function(_, _, _) | Type::MutableReference(_) | Type::Forall(_, _) @@ -745,9 +795,9 @@ impl Type { | Type::Integer(_, _) | Type::Bool | Type::Unit - | Type::Constant(_) | Type::TypeVariable(_, _) - | Type::NamedGeneric(_, _) + // TODO + | Type::GenericArith(_) | Type::Error => true, Type::FmtString(_, _) @@ -784,7 +834,7 @@ impl Type { pub fn generic_count(&self) -> usize { match self { Type::Forall(generics, _) => generics.len(), - Type::TypeVariable(type_variable, _) | Type::NamedGeneric(type_variable, _) => { + Type::TypeVariable(type_variable, _) | Type::GenericArith(GenericArith::NamedGeneric(type_variable, _)) => { match &*type_variable.borrow() { TypeBinding::Bound(binding) => binding.generic_count(), TypeBinding::Unbound(_) => 0, @@ -898,12 +948,7 @@ impl std::fmt::Display for Type { } Type::Unit => write!(f, "()"), Type::Error => write!(f, "error"), - Type::NamedGeneric(binding, name) => match &*binding.borrow() { - TypeBinding::Bound(binding) => binding.fmt(f), - TypeBinding::Unbound(_) if name.is_empty() => write!(f, "_"), - TypeBinding::Unbound(_) => write!(f, "{name}"), - }, - Type::Constant(x) => x.fmt(f), + Type::GenericArith(generic_arith) => write!(f, "{}", generic_arith), Type::Forall(typevars, typ) => { let typevars = vecmap(typevars, |var| var.id().to_string()); write!(f, "forall {}. {}", typevars.join(" "), typ) @@ -973,7 +1018,7 @@ impl Type { let this = self.substitute(bindings).follow_bindings(); match &this { - Type::Constant(length) if *length == target_length => { + Type::GenericArith(GenericArith::Constant(length)) if *length == target_length => { bindings.insert(target_id, (var.clone(), this)); Ok(()) } @@ -1138,7 +1183,7 @@ impl Type { fn get_inner_type_variable(&self) -> Option> { match self { - Type::TypeVariable(var, _) | Type::NamedGeneric(var, _) => Some(var.1.clone()), + Type::TypeVariable(var, _) | Type::GenericArith(GenericArith::NamedGeneric(var, _)) => Some(var.1.clone()), _ => None, } } @@ -1172,6 +1217,7 @@ impl Type { bindings: &mut TypeBindings, ) -> Result<(), UnificationError> { use Type::*; + use crate::GenericArith; use TypeVariableKind as Kind; match (self, other) { @@ -1249,7 +1295,7 @@ impl Type { } } - (NamedGeneric(binding, _), other) | (other, NamedGeneric(binding, _)) + (GenericArith(GenericArith::NamedGeneric(binding, _), _), other) | (other, GenericArith(GenericArith::NamedGeneric(binding, _), _)) if !binding.borrow().is_unbound() => { if let TypeBinding::Bound(link) = &*binding.borrow() { @@ -1259,18 +1305,31 @@ impl Type { } } - (NamedGeneric(binding_a, name_a), NamedGeneric(binding_b, name_b)) => { + (GenericArith(GenericArith::NamedGeneric(binding_a, name_a), equiv_a), GenericArith(GenericArith::NamedGeneric(binding_b, name_b), equiv_b)) => { // Bound NamedGenerics are caught by the check above assert!(binding_a.borrow().is_unbound()); assert!(binding_b.borrow().is_unbound()); + // TODO: cleanup condition if name_a == name_b { Ok(()) } else { - Err(UnificationError) + // Err(UnificationError) + (*equiv_a).push(GenericArith::NamedGeneric(binding_b, name_b)); + Ok(()) } } + (GenericArith(GenericArith::Op { kind: kind_a, lhs: lhs_a, rhs: rhs_a }, equiv_a), GenericArith(GenericArith::Op { kind: kind_b, lhs: lhs_b, rhs: rhs_b }, equiv_b)) => { + + () + } + + (GenericArith(GenericArith::Op { kind, lhs, rhs }, equiv), other) | (other, GenericArith(GenericArith::Op { kind, lhs, rhs }, equiv)) => { + + () + } + (Function(params_a, ret_a, env_a), Function(params_b, ret_b, env_b)) => { if params_a.len() == params_b.len() { for (a, b) in params_a.iter().zip(params_b.iter()) { @@ -1382,8 +1441,8 @@ impl Type { } } - /// If this type is a Type::Constant (used in array lengths), or is bound - /// to a Type::Constant, return the constant as a u64. + /// If this type is a Type::GenericArith(GenericArith::Constant) (used in array lengths), or is bound + /// to a Type::GenericArith(GenericArith::Constant), return the constant as a u64. pub fn evaluate_to_u64(&self) -> Option { if let Some(binding) = self.get_inner_type_variable() { if let TypeBinding::Bound(binding) = &*binding.borrow() { @@ -1394,7 +1453,7 @@ impl Type { match self { Type::TypeVariable(_, TypeVariableKind::Constant(size)) => Some(*size), Type::Array(len, _elem) => len.evaluate_to_u64(), - Type::Constant(x) => Some(*x), + Type::GenericArith(GenericArith::Constant(x)) => Some(*x), _ => None, } } @@ -1549,7 +1608,7 @@ impl Type { let fields = fields.substitute_helper(type_bindings, substitute_bound_typevars); Type::FmtString(Box::new(size), Box::new(fields)) } - Type::NamedGeneric(binding, _) | Type::TypeVariable(binding, _) => { + Type::GenericArith(GenericArith::NamedGeneric(binding, _)) | Type::TypeVariable(binding, _) => { substitute_binding(binding) } // Do not substitute_helper fields, it can lead to infinite recursion @@ -1603,7 +1662,7 @@ impl Type { Type::FieldElement | Type::Integer(_, _) | Type::Bool - | Type::Constant(_) + | Type::GenericArith(GenericArith::Constant(_)) | Type::Error | Type::Code | Type::Unit => self.clone(), @@ -1627,7 +1686,7 @@ impl Type { generic_args.iter().any(|arg| arg.occurs(target_id)) } Type::Tuple(fields) => fields.iter().any(|field| field.occurs(target_id)), - Type::NamedGeneric(binding, _) | Type::TypeVariable(binding, _) => { + Type::GenericArith(GenericArith::NamedGeneric(binding, _)) | Type::TypeVariable(binding, _) => { match &*binding.borrow() { TypeBinding::Bound(binding) => binding.occurs(target_id), TypeBinding::Unbound(id) => *id == target_id, @@ -1646,7 +1705,7 @@ impl Type { Type::FieldElement | Type::Integer(_, _) | Type::Bool - | Type::Constant(_) + | Type::GenericArith(GenericArith::Constant(_)) | Type::Error | Type::Code | Type::Unit => false, @@ -1661,6 +1720,7 @@ impl Type { /// Expected to be called on an instantiated type (with no Type::Foralls) pub fn follow_bindings(&self) -> Type { use Type::*; + use crate::GenericArith; match self { Array(size, elem) => { Array(Box::new(size.follow_bindings()), Box::new(elem.follow_bindings())) @@ -1682,7 +1742,7 @@ impl Type { def.borrow().get_type(args).follow_bindings() } Tuple(args) => Tuple(vecmap(args, |arg| arg.follow_bindings())), - TypeVariable(var, _) | NamedGeneric(var, _) => { + TypeVariable(var, _) | GenericArith(GenericArith::NamedGeneric(var, _), _) => { if let TypeBinding::Bound(typ) = &*var.borrow() { return typ.follow_bindings(); } @@ -1705,7 +1765,7 @@ impl Type { // Expect that this function should only be called on instantiated types Forall(..) => unreachable!(), - FieldElement | Integer(_, _) | Bool | Constant(_) | Unit | Code | Error => self.clone(), + FieldElement | Integer(_, _) | Bool | GenericArith(GenericArith::Constant(_), _) | Unit | Code | Error => self.clone(), } } @@ -1768,7 +1828,7 @@ impl TypeVariableKind { match self { TypeVariableKind::IntegerOrField => Some(Type::default_int_or_field_type()), TypeVariableKind::Integer => Some(Type::default_int_type()), - TypeVariableKind::Constant(length) => Some(Type::Constant(*length)), + TypeVariableKind::Constant(length) => Some(Type::GenericArith(GenericArith::Constant(*length))), TypeVariableKind::Normal => None, } } @@ -1819,7 +1879,7 @@ impl From<&Type> for PrintableType { Type::FmtString(_, _) => unreachable!("format strings cannot be printed"), Type::Error => unreachable!(), Type::Unit => PrintableType::Unit, - Type::Constant(_) => unreachable!(), + Type::GenericArith(GenericArith::Constant(_)) => unreachable!(), Type::Struct(def, ref args) => { let struct_type = def.borrow(); let fields = struct_type.get_fields(args); @@ -1830,7 +1890,7 @@ impl From<&Type> for PrintableType { Type::TraitAsType(_, _, _) => unreachable!(), Type::Tuple(types) => PrintableType::Tuple { types: vecmap(types, |typ| typ.into()) }, Type::TypeVariable(_, _) => unreachable!(), - Type::NamedGeneric(..) => unreachable!(), + Type::GenericArith(GenericArith::NamedGeneric(..)) => unreachable!(), Type::Forall(..) => unreachable!(), Type::Function(arguments, return_type, env) => PrintableType::Function { arguments: arguments.iter().map(|arg| arg.into()).collect(), @@ -1906,8 +1966,16 @@ impl std::fmt::Debug for Type { } Type::Unit => write!(f, "()"), Type::Error => write!(f, "error"), - Type::NamedGeneric(binding, name) => write!(f, "{}{:?}", name, binding), - Type::Constant(x) => x.fmt(f), + Type::GenericArith(GenericArith::NamedGeneric(binding, name), equiv) => { + (write!(f, "{}{:?}", name, binding), equiv) + }, + Type::GenericArith(GenericArith::Constant(x), equiv) => { + if equiv.borrow().len() == 0 { + x.fmt(f) + } else { + write!(f, "{}({:?})", x, equiv.borrow()) + } + }, Type::Forall(typevars, typ) => { let typevars = vecmap(typevars, |var| format!("{:?}", var)); write!(f, "forall {}. {:?}", typevars.join(" "), typ) diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index b2cc7eee9f8..31a5546d19c 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -1368,6 +1368,16 @@ fn lambda$f1(mut env$l1: (Field)) -> Field { assert_eq!(get_program_errors(src).len(), 0); } + #[test] + fn numeric_generic_arith() { + let src = r#" + type Outer = [u8; 2 * N + 1]; + fn main(_arg: Outer<1>) {} + "#; + // assert_eq!(get_program_errors(src).len(), 0); + assert_eq!(get_program_errors(src), vec![]); + } + #[test] fn ban_mutable_globals() { // Mutable globals are only allowed in a comptime context diff --git a/test_programs/compile_success_empty/numeric_generic_arith/Nargo.toml b/test_programs/compile_success_empty/numeric_generic_arith/Nargo.toml new file mode 100644 index 00000000000..9a68b489759 --- /dev/null +++ b/test_programs/compile_success_empty/numeric_generic_arith/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "numeric_generic_arith" +type = "bin" +authors = [""] +compiler_version = ">=0.28.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/compile_success_empty/numeric_generic_arith/src/main.nr b/test_programs/compile_success_empty/numeric_generic_arith/src/main.nr new file mode 100644 index 00000000000..360fbabae75 --- /dev/null +++ b/test_programs/compile_success_empty/numeric_generic_arith/src/main.nr @@ -0,0 +1,2 @@ +type Outer = [u8; 2 * N + 1]; +fn main(_arg: Outer<1>) {} From d02d6192cccedfde2a02871a80ea23c18cba1443 Mon Sep 17 00:00:00 2001 From: "Michael J. Klein" Date: Mon, 6 May 2024 15:01:32 -0400 Subject: [PATCH 2/3] wip propagating changes from GenericArith, pushing cloned constraints to shared vec --- .../src/hir/comptime/hir_to_ast.rs | 8 +- .../noirc_frontend/src/hir/comptime/value.rs | 2 +- .../src/hir/resolution/resolver.rs | 35 ++++----- .../noirc_frontend/src/hir/type_check/expr.rs | 8 +- .../noirc_frontend/src/hir/type_check/mod.rs | 4 +- compiler/noirc_frontend/src/hir_def/types.rs | 73 +++++++++++-------- 6 files changed, 72 insertions(+), 58 deletions(-) diff --git a/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs b/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs index 1ab9c13ea25..be11de18108 100644 --- a/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs +++ b/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs @@ -300,9 +300,10 @@ impl Type { fn to_type_expression(&self) -> UnresolvedTypeExpression { let span = Span::default(); + // NOTE: hir_to_ast deprecated match self.follow_bindings() { - Type::Constant(length) => UnresolvedTypeExpression::Constant(length, span), - Type::NamedGeneric(_, name) => { + Type::GenericArith(GenericArith::Constant(length), _) => UnresolvedTypeExpression::Constant(length, span), + Type::GenericArith(GenericArith::NamedGeneric(_, name), _) => { let path = Path::from_single(name.as_ref().clone(), span); UnresolvedTypeExpression::Variable(path) } @@ -342,7 +343,8 @@ impl HirArrayLiteral { HirArrayLiteral::Repeated { repeated_element, length } => { let repeated_element = Box::new(repeated_element.to_ast(interner)); let length = match length { - Type::Constant(length) => { + // NOTE: hir_to_ast deprecated + Type::GenericArith(GenericArith::Constant(length), _) => { let literal = Literal::Integer((length as u128).into(), false); let kind = ExpressionKind::Literal(literal); Box::new(Expression::new(kind, span)) diff --git a/compiler/noirc_frontend/src/hir/comptime/value.rs b/compiler/noirc_frontend/src/hir/comptime/value.rs index 6845c6ac5a9..b9380a68f5a 100644 --- a/compiler/noirc_frontend/src/hir/comptime/value.rs +++ b/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -51,7 +51,7 @@ impl Value { Value::U32(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::ThirtyTwo), Value::U64(_) => Type::Integer(Signedness::Unsigned, IntegerBitSize::SixtyFour), Value::String(value) => { - let length = Type::Constant(value.len() as u64); + let length = Type::GenericArith(GenericArith::Constant(value.len() as u64), vec![].into()); Type::String(Box::new(length)) } Value::Function(_, typ) => return Cow::Borrowed(typ), diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 062a4d840f7..05e889589b0 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -767,7 +767,7 @@ impl<'a> Resolver<'a> { if let Some(error) = error { self.push_err(error.into()); } - Some(Type::GenericArith(GenericArith::Constant(self.eval_global_as_array_length(id, path)))) + Some(Type::GenericArith(GenericArith::Constant(self.eval_global_as_array_length(id, path)), vec![].into())) } _ => None, } @@ -787,7 +787,7 @@ impl<'a> Resolver<'a> { // 'Named'Generic is a bit of a misnomer here, we want a type variable that // wont be bound over but this one has no name since we do not currently // require users to explicitly be generic over array lengths. - Type::GenericArith(GenericArith::NamedGeneric(typevar, Rc::new("".into()))) + Type::GenericArith(GenericArith::NamedGeneric(typevar, Rc::new("".into())), vec![].into()) } Some(length) => self.convert_expression_type(length), } @@ -798,10 +798,10 @@ impl<'a> Resolver<'a> { UnresolvedTypeExpression::Variable(path) => { self.lookup_generic_or_global_type(&path).unwrap_or_else(|| { self.push_err(ResolverError::NoSuchNumericTypeVariable { path }); - Type::GenericArith(GenericArith::Constant(0)) + Type::GenericArith(GenericArith::Constant(0), vec![].into()) }) } - UnresolvedTypeExpression::Constant(int, _) => Type::GenericArith(GenericArith::Constant(int)), + UnresolvedTypeExpression::Constant(int, _) => Type::GenericArith(GenericArith::Constant(int), vec![].into()), UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, op_span) => { let (lhs_span, rhs_span) = (lhs.span(), rhs.span()); let lhs = self.convert_expression_type(*lhs); @@ -812,10 +812,10 @@ impl<'a> Resolver<'a> { // Type::Constant(op.function()(lhs, rhs)) // } - (Type::GenericArith(rhs), Type::GenericArith(lhs)) => { + (Type::GenericArith(rhs, _), Type::GenericArith(lhs, _)) => { match (lhs, rhs) { (GenericArith::Constant(lhs), GenericArith::Constant(rhs)) => { - Type::GenericArith(GenericArith::Constant(op.function()(lhs, rhs))) + Type::GenericArith(GenericArith::Constant(op.function()(lhs, rhs)), vec![].into()) } (lhs, rhs) => { match op { @@ -824,25 +824,25 @@ impl<'a> Resolver<'a> { kind: GenericArithOpKind::Add, lhs: Box::new(lhs), rhs: Box::new(rhs), - }) + }, vec![].into()) } BinaryTypeOperator::Multiplication => { Type::GenericArith(GenericArith::Op { kind: GenericArithOpKind::Mul, lhs: Box::new(lhs), rhs: Box::new(rhs), - }) + }, vec![].into()) } BinaryTypeOperator::Subtraction => { Type::GenericArith(GenericArith::Op { kind: GenericArithOpKind::Sub, lhs: Box::new(lhs), rhs: Box::new(rhs), - }) + }, vec![].into()) } _ => { self.push_err(ResolverError::InvalidGenericArithOp { span: op_span }); - Type::GenericArith(GenericArith::Constant(0)) + Type::GenericArith(GenericArith::Constant(0), vec![].into()) } } } @@ -859,7 +859,7 @@ impl<'a> Resolver<'a> { // // TODO: revert before PR // dbg!("convert_expression_type", &lhs, &rhs); self.push_err(ResolverError::InvalidArrayLengthExpr { span: op_span }); - Type::GenericArith(GenericArith::Constant(0)) + Type::GenericArith(GenericArith::Constant(0), vec![].into()) } } @@ -1240,9 +1240,10 @@ impl<'a> Resolver<'a> { | Type::Unit | Type::Error | Type::TypeVariable(_, _) - // TODO + // TODO does GenericArith have relevant cases? // | Type::Constant(_) // | Type::NamedGeneric(_, _) + | Type::GenericArith(..) | Type::Code | Type::Forall(_, _) => (), @@ -1253,7 +1254,7 @@ impl<'a> Resolver<'a> { } Type::Array(length, element_type) => { - if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, name)) = length.as_ref() { + if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, name), _) = length.as_ref() { found.insert(name.to_string(), type_variable.clone()); } Self::find_numeric_generics_in_type(element_type, found); @@ -1278,7 +1279,7 @@ impl<'a> Resolver<'a> { Type::Struct(struct_type, generics) => { for (i, generic) in generics.iter().enumerate() { - if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, name)) = generic { + if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, name), _) = generic { if struct_type.borrow().generic_is_numeric(i) { found.insert(name.to_string(), type_variable.clone()); } @@ -1289,7 +1290,7 @@ impl<'a> Resolver<'a> { } Type::Alias(alias, generics) => { for (i, generic) in generics.iter().enumerate() { - if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, name)) = generic { + if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, name), _) = generic { if alias.borrow().generic_is_numeric(i) { found.insert(name.to_string(), type_variable.clone()); } @@ -1300,12 +1301,12 @@ impl<'a> Resolver<'a> { } Type::MutableReference(element) => Self::find_numeric_generics_in_type(element, found), Type::String(length) => { - if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, name)) = length.as_ref() { + if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, name), _) = length.as_ref() { found.insert(name.to_string(), type_variable.clone()); } } Type::FmtString(length, fields) => { - if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, name)) = length.as_ref() { + if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, name), _) = length.as_ref() { found.insert(name.to_string(), type_variable.clone()); } Self::find_numeric_generics_in_type(fields, found); diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index 9b40c959981..b74fccfdbfa 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -85,7 +85,7 @@ impl<'interner> TypeChecker<'interner> { HirArrayLiteral::Repeated { repeated_element, length } => { let elem_type = self.check_expression(&repeated_element); let length = match length { - Type::Constant(length) => Ok(length), + Type::GenericArith(GenericArith::Constant(length), Shared::new(vec![])) => Ok(length), other => Err(Box::new(other)), }; (length, Box::new(elem_type)) @@ -130,11 +130,11 @@ impl<'interner> TypeChecker<'interner> { HirLiteral::Bool(_) => Type::Bool, HirLiteral::Integer(_, _) => self.polymorphic_integer_or_field(), HirLiteral::Str(string) => { - let len = Type::Constant(string.len() as u64); + let len = Type::GenericArith(GenericArith::Constant(string.len() as u64), vec![].into()); Type::String(Box::new(len)) } HirLiteral::FmtStr(string, idents) => { - let len = Type::Constant(string.len() as u64); + let len = Type::GenericArith(GenericArith::Constant(string.len() as u64), vec![].into()); let types = vecmap(&idents, |elem| self.check_expression(elem)); Type::FmtString(Box::new(len), Box::new(Type::Tuple(types))) } @@ -934,7 +934,7 @@ impl<'interner> TypeChecker<'interner> { }); None } - Type::NamedGeneric(_, _) => { + Type::GenericArith(GenericArith::NamedGeneric(_, _), _) => { let func_meta = self.interner.function_meta( &self.current_function.expect("unexpected method outside a function"), ); diff --git a/compiler/noirc_frontend/src/hir/type_check/mod.rs b/compiler/noirc_frontend/src/hir/type_check/mod.rs index 6235fe3848d..9e027c4ac36 100644 --- a/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -22,7 +22,7 @@ use crate::{ traits::TraitConstraint, }, node_interner::{ExprId, FuncId, GlobalId, NodeInterner}, - Type, TypeBindings, + Type, TypeBindings, GenericArith, }; use self::errors::Source; @@ -272,7 +272,7 @@ pub(crate) fn check_trait_impl_method_matches_declaration( for ((_, trait_fn_generic), (name, impl_fn_generic)) in trait_fn_meta.direct_generics.iter().zip(&meta.direct_generics) { - let arg = Type::NamedGeneric(impl_fn_generic.clone(), name.clone()); + let arg = Type::GenericArith(GenericArith::NamedGeneric(impl_fn_generic.clone(), name.clone()), vec![].into()); bindings.insert(trait_fn_generic.id(), (trait_fn_generic.clone(), arg)); } diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 8dc8795f6fe..10e3799ea2a 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -146,7 +146,7 @@ impl std::fmt::Display for GenericArith { } // TODO: relocate -#[derive(PartialEq, Eq, Clone, Hash)] +#[derive(Debug, PartialEq, Eq, Clone, Hash)] pub enum GenericArithOpKind { Mul, Add, @@ -198,7 +198,7 @@ impl Type { | Type::Function(_, _, _) | Type::MutableReference(_) | Type::Forall(_, _) - | Type::GenericArith(_) + | Type::GenericArith(..) | Type::Code | Type::Slice(_) | Type::Error => unreachable!("This type cannot exist as a parameter to main"), @@ -669,10 +669,11 @@ impl Type { ) } + // TODO: update to recurse on GenericArith fn contains_numeric_typevar(&self, target_id: TypeVariableId) -> bool { // True if the given type is a NamedGeneric with the target_id let named_generic_id_matches_target = |typ: &Type| { - if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, _)) = typ { + if let Type::GenericArith(GenericArith::NamedGeneric(type_variable, _), _) = typ { match &*type_variable.borrow() { TypeBinding::Bound(_) => { unreachable!("Named generics should not be bound until monomorphization") @@ -691,7 +692,7 @@ impl Type { | Type::Unit | Type::Error | Type::TypeVariable(_, _) - | Type::GenericArith(_) + | Type::GenericArith(..) | Type::Forall(_, _) | Type::Code => false, @@ -750,12 +751,12 @@ impl Type { | Type::Integer(_, _) | Type::Bool | Type::Unit - | Type::GenericArith(GenericArith::Constant(_)) + | Type::GenericArith(GenericArith::Constant(_), _) | Type::Error => true, Type::FmtString(_, _) | Type::TypeVariable(_, _) - | Type::GenericArith(_) + | Type::GenericArith(..) | Type::Function(_, _, _) | Type::MutableReference(_) | Type::Forall(_, _) @@ -796,8 +797,8 @@ impl Type { | Type::Bool | Type::Unit | Type::TypeVariable(_, _) - // TODO - | Type::GenericArith(_) + // TODO ensure cases valid for program input + | Type::GenericArith(..) | Type::Error => true, Type::FmtString(_, _) @@ -834,7 +835,8 @@ impl Type { pub fn generic_count(&self) -> usize { match self { Type::Forall(generics, _) => generics.len(), - Type::TypeVariable(type_variable, _) | Type::GenericArith(GenericArith::NamedGeneric(type_variable, _)) => { + // TODO recurse on GenericArith + Type::TypeVariable(type_variable, _) | Type::GenericArith(GenericArith::NamedGeneric(type_variable, _), _) => { match &*type_variable.borrow() { TypeBinding::Bound(binding) => binding.generic_count(), TypeBinding::Unbound(_) => 0, @@ -948,7 +950,7 @@ impl std::fmt::Display for Type { } Type::Unit => write!(f, "()"), Type::Error => write!(f, "error"), - Type::GenericArith(generic_arith) => write!(f, "{}", generic_arith), + Type::GenericArith(generic_arith, equiv_generics) => write!(f, "{}({:?})", generic_arith, equiv_generics), Type::Forall(typevars, typ) => { let typevars = vecmap(typevars, |var| var.id().to_string()); write!(f, "forall {}. {}", typevars.join(" "), typ) @@ -1018,7 +1020,7 @@ impl Type { let this = self.substitute(bindings).follow_bindings(); match &this { - Type::GenericArith(GenericArith::Constant(length)) if *length == target_length => { + Type::GenericArith(GenericArith::Constant(length), _) if *length == target_length => { bindings.insert(target_id, (var.clone(), this)); Ok(()) } @@ -1183,7 +1185,7 @@ impl Type { fn get_inner_type_variable(&self) -> Option> { match self { - Type::TypeVariable(var, _) | Type::GenericArith(GenericArith::NamedGeneric(var, _)) => Some(var.1.clone()), + Type::TypeVariable(var, _) | Type::GenericArith(GenericArith::NamedGeneric(var, _), _) => Some(var.1.clone()), _ => None, } } @@ -1315,20 +1317,25 @@ impl Type { Ok(()) } else { // Err(UnificationError) - (*equiv_a).push(GenericArith::NamedGeneric(binding_b, name_b)); + + // TODO: before PR: these clone()'s detach GenericArith from the equiv_generics + // ensure that they're kept up to date + included as part of the expression + equiv_a.borrow_mut().push(GenericArith::NamedGeneric(binding_b.clone(), name_b.clone())); Ok(()) } } - (GenericArith(GenericArith::Op { kind: kind_a, lhs: lhs_a, rhs: rhs_a }, equiv_a), GenericArith(GenericArith::Op { kind: kind_b, lhs: lhs_b, rhs: rhs_b }, equiv_b)) => { - - () - } - - (GenericArith(GenericArith::Op { kind, lhs, rhs }, equiv), other) | (other, GenericArith(GenericArith::Op { kind, lhs, rhs }, equiv)) => { - - () - } + // TODO + // + // (GenericArith(GenericArith::Op { kind: kind_a, lhs: lhs_a, rhs: rhs_a }, equiv_a), GenericArith(GenericArith::Op { kind: kind_b, lhs: lhs_b, rhs: rhs_b }, equiv_b)) => { + // + // () + // } + // + // (GenericArith(GenericArith::Op { kind, lhs, rhs }, equiv), other) | (other, GenericArith(GenericArith::Op { kind, lhs, rhs }, equiv)) => { + // + // () + // } (Function(params_a, ret_a, env_a), Function(params_b, ret_b, env_b)) => { if params_a.len() == params_b.len() { @@ -1453,7 +1460,7 @@ impl Type { match self { Type::TypeVariable(_, TypeVariableKind::Constant(size)) => Some(*size), Type::Array(len, _elem) => len.evaluate_to_u64(), - Type::GenericArith(GenericArith::Constant(x)) => Some(*x), + Type::GenericArith(GenericArith::Constant(x), _) => Some(*x), _ => None, } } @@ -1608,7 +1615,7 @@ impl Type { let fields = fields.substitute_helper(type_bindings, substitute_bound_typevars); Type::FmtString(Box::new(size), Box::new(fields)) } - Type::GenericArith(GenericArith::NamedGeneric(binding, _)) | Type::TypeVariable(binding, _) => { + Type::GenericArith(GenericArith::NamedGeneric(binding, _), _) | Type::TypeVariable(binding, _) => { substitute_binding(binding) } // Do not substitute_helper fields, it can lead to infinite recursion @@ -1662,7 +1669,7 @@ impl Type { Type::FieldElement | Type::Integer(_, _) | Type::Bool - | Type::GenericArith(GenericArith::Constant(_)) + | Type::GenericArith(GenericArith::Constant(_), _) | Type::Error | Type::Code | Type::Unit => self.clone(), @@ -1686,7 +1693,7 @@ impl Type { generic_args.iter().any(|arg| arg.occurs(target_id)) } Type::Tuple(fields) => fields.iter().any(|field| field.occurs(target_id)), - Type::GenericArith(GenericArith::NamedGeneric(binding, _)) | Type::TypeVariable(binding, _) => { + Type::GenericArith(GenericArith::NamedGeneric(binding, _), _) | Type::TypeVariable(binding, _) => { match &*binding.borrow() { TypeBinding::Bound(binding) => binding.occurs(target_id), TypeBinding::Unbound(id) => *id == target_id, @@ -1705,7 +1712,7 @@ impl Type { Type::FieldElement | Type::Integer(_, _) | Type::Bool - | Type::GenericArith(GenericArith::Constant(_)) + | Type::GenericArith(GenericArith::Constant(_), _) | Type::Error | Type::Code | Type::Unit => false, @@ -1828,7 +1835,7 @@ impl TypeVariableKind { match self { TypeVariableKind::IntegerOrField => Some(Type::default_int_or_field_type()), TypeVariableKind::Integer => Some(Type::default_int_type()), - TypeVariableKind::Constant(length) => Some(Type::GenericArith(GenericArith::Constant(*length))), + TypeVariableKind::Constant(length) => Some(Type::GenericArith(GenericArith::Constant(*length), vec![].into())), TypeVariableKind::Normal => None, } } @@ -1879,7 +1886,7 @@ impl From<&Type> for PrintableType { Type::FmtString(_, _) => unreachable!("format strings cannot be printed"), Type::Error => unreachable!(), Type::Unit => PrintableType::Unit, - Type::GenericArith(GenericArith::Constant(_)) => unreachable!(), + Type::GenericArith(GenericArith::Constant(_), _) => unreachable!(), Type::Struct(def, ref args) => { let struct_type = def.borrow(); let fields = struct_type.get_fields(args); @@ -1890,7 +1897,7 @@ impl From<&Type> for PrintableType { Type::TraitAsType(_, _, _) => unreachable!(), Type::Tuple(types) => PrintableType::Tuple { types: vecmap(types, |typ| typ.into()) }, Type::TypeVariable(_, _) => unreachable!(), - Type::GenericArith(GenericArith::NamedGeneric(..)) => unreachable!(), + Type::GenericArith(GenericArith::NamedGeneric(..), _) => unreachable!(), Type::Forall(..) => unreachable!(), Type::Function(arguments, return_type, env) => PrintableType::Function { arguments: arguments.iter().map(|arg| arg.into()).collect(), @@ -1967,7 +1974,11 @@ impl std::fmt::Debug for Type { Type::Unit => write!(f, "()"), Type::Error => write!(f, "error"), Type::GenericArith(GenericArith::NamedGeneric(binding, name), equiv) => { - (write!(f, "{}{:?}", name, binding), equiv) + if equiv.borrow().len() == 0 { + write!(f, "{}{:?}", name, binding) + } else { + write!(f, "{}{:?}({:?})", name, binding, equiv.borrow()) + } }, Type::GenericArith(GenericArith::Constant(x), equiv) => { if equiv.borrow().len() == 0 { From fb2dd44c87ed8225469364b7e4b698ab9b30765d Mon Sep 17 00:00:00 2001 From: "Michael J. Klein" Date: Tue, 7 May 2024 16:58:28 -0400 Subject: [PATCH 3/3] wip implementing substitution and monomorphization checks --- .../src/hir/comptime/hir_to_ast.rs | 8 +-- .../noirc_frontend/src/hir/comptime/value.rs | 2 +- .../noirc_frontend/src/hir/type_check/expr.rs | 5 +- compiler/noirc_frontend/src/hir_def/types.rs | 63 ++++++++++++++++++- .../src/monomorphization/errors.rs | 7 ++- .../src/monomorphization/mod.rs | 39 +++++++++--- compiler/noirc_frontend/src/node_interner.rs | 6 +- 7 files changed, 105 insertions(+), 25 deletions(-) diff --git a/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs b/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs index be11de18108..671d47c0e0c 100644 --- a/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs +++ b/compiler/noirc_frontend/src/hir/comptime/hir_to_ast.rs @@ -11,7 +11,7 @@ use crate::ast::{ use crate::ast::{ConstrainStatement, Expression, Statement, StatementKind}; use crate::hir_def::expr::{HirArrayLiteral, HirBlockExpression, HirExpression, HirIdent}; use crate::hir_def::stmt::{HirLValue, HirPattern, HirStatement}; -use crate::hir_def::types::Type; +use crate::hir_def::types::{GenericArith, Type}; use crate::macros_api::HirLiteral; use crate::node_interner::{ExprId, NodeInterner, StmtId}; @@ -270,10 +270,6 @@ impl Type { let name = Path::from_single(name.as_ref().clone(), Span::default()); UnresolvedTypeData::TraitAsType(name, generics) } - Type::NamedGeneric(_, name) => { - let name = Path::from_single(name.as_ref().clone(), Span::default()); - UnresolvedTypeData::TraitAsType(name, Vec::new()) - } Type::Function(args, ret, env) => { let args = vecmap(args, |arg| arg.to_ast()); let ret = Box::new(ret.to_ast()); @@ -289,9 +285,9 @@ impl Type { // Since there is no UnresolvedTypeData equivalent for Type::Forall, we use // this to ignore this case since it shouldn't be needed anyway. Type::Forall(_, typ) => return typ.to_ast(), - Type::Constant(_) => panic!("Type::Constant where a type was expected: {self:?}"), Type::Code => UnresolvedTypeData::Code, Type::Error => UnresolvedTypeData::Error, + _ => unimplemented!("TODO: hir_to_ast deprecated"), }; UnresolvedType { typ, span: None } diff --git a/compiler/noirc_frontend/src/hir/comptime/value.rs b/compiler/noirc_frontend/src/hir/comptime/value.rs index b9380a68f5a..56a2a7300b0 100644 --- a/compiler/noirc_frontend/src/hir/comptime/value.rs +++ b/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -10,7 +10,7 @@ use crate::{ hir_def::expr::{HirArrayLiteral, HirConstructorExpression, HirIdent, HirLambda, ImplKind}, macros_api::{HirExpression, HirLiteral, NodeInterner}, node_interner::{ExprId, FuncId}, - Shared, Type, + Shared, Type, GenericArith, }; use rustc_hash::FxHashMap as HashMap; diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index b74fccfdbfa..676eb1c89d2 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -13,7 +13,7 @@ use crate::{ types::Type, }, node_interner::{DefinitionKind, ExprId, FuncId, TraitId, TraitImplKind, TraitMethodId}, - TypeBinding, TypeBindings, TypeVariableKind, + TypeBinding, TypeBindings, TypeVariableKind, GenericArith, }; use super::{errors::TypeCheckError, TypeChecker}; @@ -84,8 +84,9 @@ impl<'interner> TypeChecker<'interner> { } HirArrayLiteral::Repeated { repeated_element, length } => { let elem_type = self.check_expression(&repeated_element); + // TODO: before PR: support other cases here? let length = match length { - Type::GenericArith(GenericArith::Constant(length), Shared::new(vec![])) => Ok(length), + Type::GenericArith(GenericArith::Constant(length), _) => Ok(length), other => Err(Box::new(other)), }; (length, Box::new(elem_type)) diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index 10e3799ea2a..9c88b53b3a6 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -1564,6 +1564,7 @@ impl Type { self.substitute_helper(type_bindings, true) } + /// This helper function only differs in the additional parameter which, if set, /// allows substitutions on already-bound type variables. This should be `false` /// for most uses, but is currently needed during monomorphization when instantiating @@ -1618,6 +1619,15 @@ impl Type { Type::GenericArith(GenericArith::NamedGeneric(binding, _), _) | Type::TypeVariable(binding, _) => { substitute_binding(binding) } + Type::GenericArith(GenericArith::Op { kind, lhs, rhs }, constraints) => { + let lhs = Self::substitute_generic_arith_helper(*lhs.clone(), type_bindings, substitute_bound_typevars); + let rhs = Self::substitute_generic_arith_helper(*rhs.clone(), type_bindings, substitute_bound_typevars); + for constraint in constraints.borrow_mut().iter_mut() { + *constraint = Self::substitute_generic_arith_helper(constraint.clone(), type_bindings, substitute_bound_typevars) + } + + Type::GenericArith(GenericArith::Op { kind: kind.clone(), lhs: Box::new(lhs), rhs: Box::new(rhs) }, constraints.clone()) + } // Do not substitute_helper fields, it can lead to infinite recursion // and we should not match fields when type checking anyway. Type::Struct(fields, args) => { @@ -1676,6 +1686,28 @@ impl Type { } } + fn substitute_generic_arith_helper( + generic_arith: GenericArith, + type_bindings: &TypeBindings, + substitute_bound_typevars: bool, + ) -> GenericArith { + // TODO: it appears that this could be incorrect when + // 1. substitute_helper is run on a NamedGeneric + // 2. it then unwraps to a non-GenericArith + // + // ^ but can that actually happen? (i.e. does the kind of the NamedGeneric prevent it?) + + let result = Type::GenericArith(generic_arith, vec![].into()).substitute_helper(type_bindings, substitute_bound_typevars); + if let Type::GenericArith(generic_arith, constraints) = result { + if !constraints.borrow().is_empty() { + unreachable!("ICE: substitute_helper not expected to add GenericArith constraints: {:?}", constraints) + } + generic_arith + } else { + unreachable!("ICE: substitute_helper not expected to return non-GenericArith when run on GenericArith: {:?}", result) + } + } + /// True if the given TypeVariableId is free anywhere within self pub fn occurs(&self, target_id: TypeVariableId) -> bool { match self { @@ -1693,12 +1725,22 @@ impl Type { generic_args.iter().any(|arg| arg.occurs(target_id)) } Type::Tuple(fields) => fields.iter().any(|field| field.occurs(target_id)), - Type::GenericArith(GenericArith::NamedGeneric(binding, _), _) | Type::TypeVariable(binding, _) => { + Type::TypeVariable(binding, _) => { match &*binding.borrow() { TypeBinding::Bound(binding) => binding.occurs(target_id), TypeBinding::Unbound(id) => *id == target_id, } } + + Type::GenericArith(GenericArith::NamedGeneric(binding, _), constraints) => { + + todo!("follow TypeVariable case and recurse on constraints") + } + Type::GenericArith(GenericArith::Op { lhs, rhs, .. }, constraints) => { + + todo!("recurse on lhs, rhs, constraints") + } + Type::Forall(typevars, typ) => { !typevars.iter().any(|var| var.id() == target_id) && typ.occurs(target_id) } @@ -1749,13 +1791,22 @@ impl Type { def.borrow().get_type(args).follow_bindings() } Tuple(args) => Tuple(vecmap(args, |arg| arg.follow_bindings())), - TypeVariable(var, _) | GenericArith(GenericArith::NamedGeneric(var, _), _) => { + TypeVariable(var, _) => { if let TypeBinding::Bound(typ) = &*var.borrow() { return typ.follow_bindings(); } self.clone() } + GenericArith(GenericArith::NamedGeneric(var, _), constraints) => { + + todo!("need to follow TypeVariable case for (var) and recurse on (constraints)") + } + GenericArith(GenericArith::Op { lhs, rhs, .. }, constraints) => { + + todo!("need to recurse on lhs, rhs, constraints") + } + Function(args, ret, env) => { let args = vecmap(args, |arg| arg.follow_bindings()); let ret = Box::new(ret.follow_bindings()); @@ -1887,6 +1938,7 @@ impl From<&Type> for PrintableType { Type::Error => unreachable!(), Type::Unit => PrintableType::Unit, Type::GenericArith(GenericArith::Constant(_), _) => unreachable!(), + Type::GenericArith(GenericArith::Op { .. }, _) => unreachable!("generic arithmetic operations cannot be printed"), Type::Struct(def, ref args) => { let struct_type = def.borrow(); let fields = struct_type.get_fields(args); @@ -1987,6 +2039,13 @@ impl std::fmt::Debug for Type { write!(f, "{}({:?})", x, equiv.borrow()) } }, + Type::GenericArith(GenericArith::Op { kind, lhs, rhs }, equiv) => { + if equiv.borrow().len() == 0 { + write!(f, "{} {} {}", lhs, kind, rhs) + } else { + write!(f, "({} {} {})({:?})", lhs, kind, rhs, equiv.borrow()) + } + }, Type::Forall(typevars, typ) => { let typevars = vecmap(typevars, |var| format!("{:?}", var)); write!(f, "forall {}. {:?}", typevars.join(" "), typ) diff --git a/compiler/noirc_frontend/src/monomorphization/errors.rs b/compiler/noirc_frontend/src/monomorphization/errors.rs index 3011c26cffe..7c0b438d9b2 100644 --- a/compiler/noirc_frontend/src/monomorphization/errors.rs +++ b/compiler/noirc_frontend/src/monomorphization/errors.rs @@ -1,5 +1,6 @@ use thiserror::Error; +use crate::monomorphization::ast::Type; use noirc_errors::{CustomDiagnostic, FileDiagnostic, Location}; #[derive(Debug, Error)] @@ -9,13 +10,17 @@ pub enum MonomorphizationError { #[error("Type annotations needed")] TypeAnnotationsNeeded { location: Location }, + + #[error("Type annotations needed: {expected_result} != {found_type}")] + GenericArithIncomplete { location: Location, expected_result: Type, found_type: Type }, } impl MonomorphizationError { fn location(&self) -> Location { match self { MonomorphizationError::UnknownArrayLength { location } - | MonomorphizationError::TypeAnnotationsNeeded { location } => *location, + | MonomorphizationError::TypeAnnotationsNeeded { location } + | MonomorphizationError::GenericArithIncomplete { location, .. } => *location, } } } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index b5754897d3f..483564f6768 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -19,7 +19,7 @@ use crate::{ }, node_interner::{self, DefinitionKind, NodeInterner, StmtId, TraitImplKind, TraitMethodId}, token::FunctionAttribute, - Type, TypeBinding, TypeBindings, TypeVariable, TypeVariableKind, + Type, TypeBinding, TypeBindings, TypeVariable, TypeVariableKind, GenericArith, }; use acvm::FieldElement; use iter_extended::{btree_map, try_vecmap, vecmap}; @@ -915,16 +915,31 @@ impl<'interner> Monomorphizer<'interner> { HirType::TraitAsType(..) => { unreachable!("All TraitAsType should be replaced before calling convert_type"); } - HirType::NamedGeneric(binding, _) => { - if let TypeBinding::Bound(binding) = &*binding.borrow() { - return Self::convert_type(binding, location); + + // TODO: review whether sufficient + HirType::GenericArith(GenericArith::NamedGeneric(binding, _), constraints) => { + let expected_result = if let TypeBinding::Bound(binding) = &*binding.borrow() { + Self::convert_type(&binding, location)? + } else { + // Default any remaining unbound type variables. + // This should only happen if the variable in question is unused + // and within a larger generic type. + binding.bind(HirType::default_int_or_field_type()); + ast::Type::Field + }; + + if expected_result == ast::Type::Field { + for constraint in constraints.borrow().iter() { + let constraint_as_type = HirType::GenericArith(constraint.clone(), vec![].into()); + let found_type = Self::convert_type(&constraint_as_type, location)?; + + if expected_result != found_type { + return Err(MonomorphizationError::GenericArithIncomplete { location, expected_result, found_type }) + } + } } - // Default any remaining unbound type variables. - // This should only happen if the variable in question is unused - // and within a larger generic type. - binding.bind(HirType::default_int_or_field_type()); - ast::Type::Field + expected_result } HirType::TypeVariable(binding, kind) => { @@ -983,7 +998,11 @@ impl<'interner> Monomorphizer<'interner> { ast::Type::MutableReference(Box::new(element)) } - HirType::Forall(_, _) | HirType::Constant(_) | HirType::Error => { + HirType::GenericArith(GenericArith::Op { .. }, _) => { + return Err(MonomorphizationError::TypeAnnotationsNeeded { location }) + } + + HirType::Forall(_, _) | HirType::GenericArith(GenericArith::Constant(_), _) | HirType::Error => { unreachable!("Unexpected type {} found", typ) } HirType::Code => unreachable!("Tried to translate Code type into runtime code"), diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index 88adc7a9414..4c3093e1225 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -30,7 +30,7 @@ use crate::hir_def::{ }; use crate::token::{Attributes, SecondaryAttribute}; use crate::{ - Generics, Shared, TypeAlias, TypeBindings, TypeVariable, TypeVariableId, TypeVariableKind, + Generics, Shared, TypeAlias, TypeBindings, TypeVariable, TypeVariableId, TypeVariableKind, GenericArith, }; /// An arbitrary number to limit the recursion depth when searching for trait impls. @@ -1771,7 +1771,7 @@ fn get_type_method_key(typ: &Type) -> Option { Type::Unit => Some(Unit), Type::Tuple(_) => Some(Tuple), Type::Function(_, _, _) => Some(Function), - Type::NamedGeneric(_, _) => Some(Generic), + Type::GenericArith(GenericArith::NamedGeneric(_, _), _) => Some(Generic), Type::Code => Some(Code), Type::MutableReference(element) => get_type_method_key(element), Type::Alias(alias, _) => get_type_method_key(&alias.borrow().typ), @@ -1779,7 +1779,7 @@ fn get_type_method_key(typ: &Type) -> Option { // We do not support adding methods to these types Type::TypeVariable(_, _) | Type::Forall(_, _) - | Type::Constant(_) + | Type::GenericArith(..) | Type::Error | Type::Struct(_, _) | Type::TraitAsType(..) => None,