diff --git a/Cargo.lock b/Cargo.lock index 4b140a513f2..378148db37e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2725,6 +2725,7 @@ version = "0.30.0" dependencies = [ "acvm", "iter-extended", + "noirc_errors", "noirc_frontend", "noirc_printable_type", "num-bigint", diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 80a63f223e7..8a35a95abf5 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -21,12 +21,16 @@ use acvm::{ FieldElement, }; -use noirc_errors::debug_info::{DebugFunctions, DebugInfo, DebugTypes, DebugVariables}; +use noirc_errors::{ + debug_info::{DebugFunctions, DebugInfo, DebugTypes, DebugVariables}, + Location, +}; use noirc_frontend::ast::Visibility; use noirc_frontend::{ hir_def::{function::FunctionSignature, types::Type as HirType}, monomorphization::ast::Program, + node_interner::NodeInterner, }; use tracing::{span, Level}; @@ -278,7 +282,9 @@ fn split_public_and_private_inputs( .0 .iter() .map(|(_, typ, visibility)| { - let num_field_elements_needed = typ.field_count() as usize; + let dummy_interner = NodeInterner::default(); + let num_field_elements_needed = + typ.field_count(&Location::dummy(), &dummy_interner) as usize; let witnesses = input_witnesses[idx..idx + num_field_elements_needed].to_vec(); idx += num_field_elements_needed; (visibility, witnesses) diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs b/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs index 5f0660f5a79..c94c0a6415e 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs @@ -3,8 +3,10 @@ use std::rc::Rc; use crate::ssa::ir::{types::Type, value::ValueId}; use acvm::FieldElement; use fxhash::FxHashMap as HashMap; +use noirc_errors::Location; use noirc_frontend::ast; use noirc_frontend::hir_def::function::FunctionSignature; +use noirc_frontend::node_interner::NodeInterner; use super::FunctionBuilder; @@ -37,7 +39,8 @@ impl DataBusBuilder { ast::Visibility::Public | ast::Visibility::Private => false, ast::Visibility::DataBus => true, }; - let len = param.1.field_count() as usize; + let dummy_interner = NodeInterner::default(); + let len = param.1.field_count(&Location::dummy(), &dummy_interner) as usize; params_is_databus.extend(vec![is_databus; len]); } params_is_databus diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index 594c0690033..4be6af056d3 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -801,6 +801,11 @@ impl<'context> Elaborator<'context> { } } } + Type::GenericArith(_, generics) => { + for generic in generics { + Self::find_numeric_generics_in_type(generic, found); + } + } Type::MutableReference(element) => Self::find_numeric_generics_in_type(element, found), Type::String(length) => { if let Type::NamedGeneric(type_variable, name) = length.as_ref() { diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index 955b4af327a..f2ca1a3771e 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -1,3 +1,4 @@ +use std::collections::HashSet; use std::rc::Rc; use acvm::acir::AcirField; @@ -29,7 +30,10 @@ use crate::{ HirExpression, HirLiteral, HirStatement, Path, PathKind, SecondaryAttribute, Signedness, UnaryOp, UnresolvedType, UnresolvedTypeData, }, - node_interner::{DefinitionKind, ExprId, GlobalId, TraitId, TraitImplKind, TraitMethodId}, + node_interner::{ + generic_arith::{ArithExpr, ArithOpKind}, + DefinitionKind, ExprId, GlobalId, TraitId, TraitImplKind, TraitMethodId, + }, Generics, Type, TypeBinding, TypeVariable, TypeVariableKind, }; @@ -272,32 +276,98 @@ impl<'context> Elaborator<'context> { pub(super) fn convert_expression_type(&mut self, length: UnresolvedTypeExpression) -> Type { match length { UnresolvedTypeExpression::Variable(path) => { - self.lookup_generic_or_global_type(&path).unwrap_or_else(|| { - self.push_err(ResolverError::NoSuchNumericTypeVariable { path }); - Type::Constant(0) - }) + let var_or_constant = + self.lookup_generic_or_global_type(&path).unwrap_or_else(|| { + self.push_err(ResolverError::NoSuchNumericTypeVariable { + path: path.clone(), + }); + Type::Constant(0) + }); + if let Type::NamedGeneric(ref binding, ref name) = var_or_constant { + // we intern variables so that they can be resolved during trait resolution + let arith_expr = + ArithExpr::Variable(binding.clone(), name.clone(), Default::default()); + let location = Location { span: path.span, file: self.file }; + let _ = self.interner.push_arith_expression(arith_expr, location); + } + var_or_constant + } + UnresolvedTypeExpression::Constant(int, span) => { + // we intern constants so that they can be resolved during trait resolution + let arith_expr = ArithExpr::Constant(int); + let location = Location { span, file: self.file }; + let _ = self.interner.push_arith_expression(arith_expr, location); + Type::Constant(int) } - UnresolvedTypeExpression::Constant(int, _) => Type::Constant(int), - UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, _) => { + UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, op_span) => { let (lhs_span, rhs_span) = (lhs.span(), rhs.span()); let lhs = self.convert_expression_type(*lhs); let rhs = self.convert_expression_type(*rhs); + let span = if !matches!(lhs, Type::Constant(_)) { lhs_span } else { rhs_span }; + let (lhs, lhs_generics) = self.convert_type_to_arith_expr(lhs, span); + let (rhs, rhs_generics) = self.convert_type_to_arith_expr(rhs, span); + match (lhs, rhs) { - (Type::Constant(lhs), Type::Constant(rhs)) => { - Type::Constant(op.function()(lhs, rhs)) + (ArithExpr::Constant(lhs), ArithExpr::Constant(rhs)) => { + let int = op.function()(lhs, rhs); + let arith_expr = ArithExpr::Constant(int); + let op_location = Location { span: op_span, file: self.file }; + let _ = self.interner.push_arith_expression(arith_expr, op_location); + assert!(lhs_generics.is_empty(), "constant generics expected to be empty"); + assert!(rhs_generics.is_empty(), "constant generics expected to be empty"); + Type::Constant(int) } - (lhs, _) => { - let span = - if !matches!(lhs, Type::Constant(_)) { lhs_span } else { rhs_span }; - self.push_err(ResolverError::InvalidArrayLengthExpr { span }); - Type::Constant(0) + + (lhs, rhs) => { + let kind = + ArithOpKind::from_binary_type_operator(op).unwrap_or_else(|| { + self.push_err(ResolverError::InvalidGenericArithOp { + span: op_span, + }); + // return a valid ArithOpKind when erroring + ArithOpKind::Add + }); + + // offset GenericIndex's in rhs to prevent overlap + let rhs = rhs.offset_generic_indices(lhs_generics.len()); + + let arith_expr = + ArithExpr::Op { kind, lhs: Box::new(lhs), rhs: Box::new(rhs) }; + let op_location = Location { span: op_span, file: self.file }; + let new_id = self.interner.push_arith_expression(arith_expr, op_location); + let new_generics = lhs_generics + .into_iter() + .chain(rhs_generics) + .collect::>() + .into_iter() + .collect::>(); + + Type::GenericArith(new_id, new_generics) } } } } } + fn convert_type_to_arith_expr(&mut self, typ: Type, span: Span) -> (ArithExpr, Vec) { + match typ { + Type::Constant(value) => (ArithExpr::Constant(value), vec![]), + Type::NamedGeneric(typevar, name) => ( + ArithExpr::Variable(typevar.clone(), name.clone(), Default::default()), + vec![Type::NamedGeneric(typevar, name)], + ), + Type::GenericArith(arith_id, generics) => { + let (arith_expr, _location) = self.interner.get_arith_expression(arith_id); + (arith_expr.clone(), generics) + } + _ => { + self.push_err(ResolverError::InvalidArrayLengthExpr { span }); + (ArithExpr::Constant(0), vec![]) + } + } + } + // this resolves Self::some_static_method, inside an impl block (where we don't have a concrete self_type) // // Returns the trait method, trait constraint, and whether the impl is assumed to exist by a where clause or not @@ -547,7 +617,7 @@ impl<'context> Elaborator<'context> { make_error: impl FnOnce() -> TypeCheckError, ) { let mut errors = Vec::new(); - actual.unify(expected, &mut errors, make_error); + actual.unify(expected, &self.interner.arith_constraints, &mut errors, make_error); self.errors.extend(errors.into_iter().map(|error| (error.into(), self.file))); } diff --git a/compiler/noirc_frontend/src/hir/comptime/errors.rs b/compiler/noirc_frontend/src/hir/comptime/errors.rs index 34cecf0ece4..69f65fe1578 100644 --- a/compiler/noirc_frontend/src/hir/comptime/errors.rs +++ b/compiler/noirc_frontend/src/hir/comptime/errors.rs @@ -1,4 +1,7 @@ -use crate::{hir::def_collector::dc_crate::CompilationError, Type}; +use crate::{ + hir::def_collector::dc_crate::CompilationError, node_interner::generic_arith::ArithExprError, + Type, +}; use acvm::{acir::AcirField, FieldElement}; use noirc_errors::{CustomDiagnostic, Location}; @@ -23,7 +26,7 @@ pub enum InterpreterError { NonArrayIndexed { value: Value, location: Location }, NonIntegerUsedAsIndex { value: Value, location: Location }, NonIntegerIntegerLiteral { typ: Type, location: Location }, - NonIntegerArrayLength { typ: Type, location: Location }, + NonIntegerArrayLength { typ: Type, location: Location, arith_expr_error: ArithExprError }, NonNumericCasted { value: Value, location: Location }, IndexOutOfBounds { index: usize, length: usize, location: Location }, ExpectedStructToHaveField { value: Value, field_name: String, location: Location }, @@ -196,9 +199,9 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { let secondary = "This is likely a bug".into(); CustomDiagnostic::simple_error(msg, secondary, location.span) } - InterpreterError::NonIntegerArrayLength { typ, location } => { + InterpreterError::NonIntegerArrayLength { typ, location, arith_expr_error } => { let msg = format!("Non-integer array length: `{typ}`"); - let secondary = "Array lengths must be integers".into(); + let secondary = format!("Array lengths must be integers\n{}", arith_expr_error); CustomDiagnostic::simple_error(msg, secondary, location.span) } InterpreterError::NonNumericCasted { value, location } => { diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index c0aeb910f22..d61e80f5768 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -21,7 +21,9 @@ use crate::{ }, }, macros_api::{HirExpression, HirLiteral, HirStatement, NodeInterner}, - node_interner::{DefinitionId, DefinitionKind, ExprId, FuncId, StmtId}, + node_interner::{ + generic_arith::ArithExprError, DefinitionId, DefinitionKind, ExprId, FuncId, StmtId, + }, Shared, Type, TypeBinding, TypeBindings, TypeVariableKind, }; @@ -284,13 +286,16 @@ impl<'a> Interpreter<'a> { Err(InterpreterError::NonComptimeVarReferenced { name, location }) } - fn type_check(&self, typ: &Type, value: &Value, location: Location) -> IResult<()> { + fn type_check(&mut self, typ: &Type, value: &Value, location: Location) -> IResult<()> { let typ = typ.follow_bindings(); let value_type = value.get_type(); - typ.try_unify(&value_type, &mut TypeBindings::new()).map_err(|_| { - InterpreterError::TypeMismatch { expected: typ, value: value.clone(), location } - }) + typ.try_unify(&value_type, &mut TypeBindings::new(), &self.interner.arith_constraints) + .map_err(|_| InterpreterError::TypeMismatch { + expected: typ, + value: value.clone(), + location, + }) } /// Evaluate an expression and return the result @@ -347,17 +352,30 @@ impl<'a> Interpreter<'a> { } DefinitionKind::GenericType(type_variable) => { let value = match &*type_variable.borrow() { - TypeBinding::Unbound(_) => None, - TypeBinding::Bound(binding) => binding.evaluate_to_u64(), + TypeBinding::Unbound(var_id) => Err(ArithExprError::UnboundVariable { + binding: type_variable.clone(), + name: format!("#type_variable_id_{}_{:?}", var_id, ident), + }), + TypeBinding::Bound(binding) => { + binding.evaluate_to_u64(&definition.location, self.interner) + } }; - if let Some(value) = value { - let typ = self.interner.id_type(id); - self.evaluate_integer((value as u128).into(), false, id) - } else { - let location = self.interner.expr_location(&id); - let typ = Type::TypeVariable(type_variable.clone(), TypeVariableKind::Normal); - Err(InterpreterError::NonIntegerArrayLength { typ, location }) + match value { + Ok(value) => { + let typ = self.interner.id_type(id); + self.evaluate_integer((value as u128).into(), false, id) + } + Err(arith_expr_error) => { + let location = self.interner.expr_location(&id); + let typ = + Type::TypeVariable(type_variable.clone(), TypeVariableKind::Normal); + Err(InterpreterError::NonIntegerArrayLength { + typ, + location, + arith_expr_error, + }) + } } } } @@ -500,13 +518,21 @@ impl<'a> Interpreter<'a> { } HirArrayLiteral::Repeated { repeated_element, length } => { let element = self.evaluate(repeated_element)?; + let location = self.interner.expr_location(&id); - if let Some(length) = length.evaluate_to_u64() { - let elements = (0..length).map(|_| element.clone()).collect(); - Ok(Value::Array(elements, typ)) - } else { - let location = self.interner.expr_location(&id); - Err(InterpreterError::NonIntegerArrayLength { typ: length, location }) + match length.evaluate_to_u64(&location, self.interner) { + Ok(length) => { + let elements = (0..length).map(|_| element.clone()).collect(); + Ok(Value::Array(elements, typ)) + } + Err(arith_expr_error) => { + let location = self.interner.expr_location(&id); + Err(InterpreterError::NonIntegerArrayLength { + typ: length, + location, + arith_expr_error, + }) + } } } } diff --git a/compiler/noirc_frontend/src/hir/mod.rs b/compiler/noirc_frontend/src/hir/mod.rs index 55dc22d6c5d..73eb7a3e88c 100644 --- a/compiler/noirc_frontend/src/hir/mod.rs +++ b/compiler/noirc_frontend/src/hir/mod.rs @@ -8,7 +8,7 @@ pub mod type_check; use crate::debug::DebugInstrumenter; use crate::graph::{CrateGraph, CrateId}; use crate::hir_def::function::FuncMeta; -use crate::node_interner::{FuncId, NodeInterner, StructId}; +use crate::node_interner::{ArithConstraints, FuncId, NodeInterner, StructId}; use crate::parser::ParserError; use crate::ParsedModule; use def_map::{Contract, CrateDefMap}; @@ -43,6 +43,8 @@ pub struct Context<'file_manager, 'parsed_files> { // Same as the file manager, we take ownership of the parsed files in the WASM context. // Parsed files is also read only. pub parsed_files: Cow<'parsed_files, ParsedFiles>, + + pub arith_constraints: ArithConstraints, } #[derive(Debug, Copy, Clone)] @@ -62,6 +64,7 @@ impl Context<'_, '_> { file_manager: Cow::Owned(file_manager), debug_instrumenter: DebugInstrumenter::default(), parsed_files: Cow::Owned(parsed_files), + arith_constraints: Vec::new().into(), } } @@ -77,6 +80,7 @@ impl Context<'_, '_> { file_manager: Cow::Borrowed(file_manager), debug_instrumenter: DebugInstrumenter::default(), parsed_files: Cow::Borrowed(parsed_files), + arith_constraints: Vec::new().into(), } } diff --git a/compiler/noirc_frontend/src/hir/resolution/errors.rs b/compiler/noirc_frontend/src/hir/resolution/errors.rs index fa4ea96316a..8c49069d382 100644 --- a/compiler/noirc_frontend/src/hir/resolution/errors.rs +++ b/compiler/noirc_frontend/src/hir/resolution/errors.rs @@ -46,6 +46,12 @@ pub enum ResolverError { MissingRhsExpr { name: String, span: Span }, #[error("Expression invalid in an array length context")] InvalidArrayLengthExpr { span: Span }, + #[error("Expression invalid for generic arithmetic: only '+', '-', and '*' are allowed.")] + InvalidGenericArithOp { span: Span }, + #[error("impl's are not allowed on generic arithmetic")] + ImplOnGenericArith { span: Span }, + #[error("impl's are not allowed to have generic arithmetic arguments")] + ImplWithGenericArith { span: Span }, #[error("Integer too large to be evaluated in an array length context")] IntegerTooLarge { span: Span }, #[error("No global or generic type parameter found with the given name")] @@ -236,6 +242,21 @@ impl<'a> From<&'a ResolverError> for Diagnostic { "Array-length expressions can only have simple integer operations and any variables used must be global constants".into(), *span, ), + ResolverError::InvalidGenericArithOp { span } => Diagnostic::simple_error( + "Expression invalid for generic arithmetic with variables: only '+', '-', and '*' are allowed.".into(), + "Generic expressions can only have constants, variables, and simple integer operations: '+', '-', and '*'".into(), + *span, + ), + ResolverError::ImplOnGenericArith { span } => Diagnostic::simple_error( + "impl's are not allowed on generic arithmetic".into(), + "For example, impl Foo for Bar contains the generic arithmetic \"N + 1\", which is not allowed".into(), + *span, + ), + ResolverError::ImplWithGenericArith { span } => Diagnostic::simple_error( + "impl's are not allowed to have generic arithmetic arguments".into(), + "For example, impl Foo for Bar contains the generic arithmetic \"N + 1\", which is not allowed".into(), + *span, + ), ResolverError::IntegerTooLarge { span } => Diagnostic::simple_error( "Integer too large to be evaluated to an array-length".into(), "Array-lengths may be a maximum size of usize::MAX, including intermediate calculations".into(), diff --git a/compiler/noirc_frontend/src/hir/resolution/resolver.rs b/compiler/noirc_frontend/src/hir/resolution/resolver.rs index 35ba964c499..33649c3c1ff 100644 --- a/compiler/noirc_frontend/src/hir/resolution/resolver.rs +++ b/compiler/noirc_frontend/src/hir/resolution/resolver.rs @@ -39,6 +39,7 @@ use crate::hir::def_map::{ModuleDefId, TryFromModuleDefId, MAIN_FUNCTION}; use crate::hir::{def_map::CrateDefMap, resolution::path_resolver::PathResolver}; use crate::hir_def::stmt::{HirAssignStatement, HirForStatement, HirLValue, HirPattern}; use crate::node_interner::{ + generic_arith::{ArithExpr, ArithOpKind}, DefinitionId, DefinitionKind, DependencyId, ExprId, FuncId, GlobalId, NodeInterner, StmtId, StructId, TraitId, TraitImplId, TraitMethodId, TypeAliasId, }; @@ -738,6 +739,7 @@ impl<'a> Resolver<'a> { } } + // returns a Type::NamedGeneric or a Type::Constant fn lookup_generic_or_global_type(&mut self, path: &Path) -> Option { if path.segments.len() == 1 { let name = &path.last_segment().0.contents; @@ -765,32 +767,98 @@ impl<'a> Resolver<'a> { fn convert_expression_type(&mut self, length: UnresolvedTypeExpression) -> Type { match length { UnresolvedTypeExpression::Variable(path) => { - self.lookup_generic_or_global_type(&path).unwrap_or_else(|| { - self.push_err(ResolverError::NoSuchNumericTypeVariable { path }); - Type::Constant(0) - }) + let var_or_constant = + self.lookup_generic_or_global_type(&path).unwrap_or_else(|| { + self.push_err(ResolverError::NoSuchNumericTypeVariable { + path: path.clone(), + }); + Type::Constant(0) + }); + if let Type::NamedGeneric(ref binding, ref name) = var_or_constant { + // we intern variables so that they can be resolved during trait resolution + let arith_expr = + ArithExpr::Variable(binding.clone(), name.clone(), Default::default()); + let location = Location { span: path.span, file: self.file }; + let _ = self.interner.push_arith_expression(arith_expr, location); + } + var_or_constant } - UnresolvedTypeExpression::Constant(int, _) => Type::Constant(int), - UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, _) => { + UnresolvedTypeExpression::Constant(int, span) => { + // we intern constants so that they can be resolved during trait resolution + let arith_expr = ArithExpr::Constant(int); + let location = Location { span, file: self.file }; + let _ = self.interner.push_arith_expression(arith_expr, location); + Type::Constant(int) + } + UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, op_span) => { let (lhs_span, rhs_span) = (lhs.span(), rhs.span()); let lhs = self.convert_expression_type(*lhs); let rhs = self.convert_expression_type(*rhs); + let span = if !matches!(lhs, Type::Constant(_)) { lhs_span } else { rhs_span }; + let (lhs, lhs_generics) = self.convert_type_to_arith_expr(lhs, span); + let (rhs, rhs_generics) = self.convert_type_to_arith_expr(rhs, span); + match (lhs, rhs) { - (Type::Constant(lhs), Type::Constant(rhs)) => { - Type::Constant(op.function()(lhs, rhs)) + (ArithExpr::Constant(lhs), ArithExpr::Constant(rhs)) => { + let int = op.function()(lhs, rhs); + let arith_expr = ArithExpr::Constant(int); + let op_location = Location { span: op_span, file: self.file }; + let _ = self.interner.push_arith_expression(arith_expr, op_location); + assert!(lhs_generics.is_empty(), "constant generics expected to be empty"); + assert!(rhs_generics.is_empty(), "constant generics expected to be empty"); + Type::Constant(int) } - (lhs, _) => { - let span = - if !matches!(lhs, Type::Constant(_)) { lhs_span } else { rhs_span }; - self.push_err(ResolverError::InvalidArrayLengthExpr { span }); - Type::Constant(0) + + (lhs, rhs) => { + let kind = + ArithOpKind::from_binary_type_operator(op).unwrap_or_else(|| { + self.push_err(ResolverError::InvalidGenericArithOp { + span: op_span, + }); + // return a valid ArithOpKind when erroring + ArithOpKind::Add + }); + + // offset GenericIndex's in rhs to prevent overlap + let rhs = rhs.offset_generic_indices(lhs_generics.len()); + + let arith_expr = + ArithExpr::Op { kind, lhs: Box::new(lhs), rhs: Box::new(rhs) }; + let op_location = Location { span: op_span, file: self.file }; + let new_id = self.interner.push_arith_expression(arith_expr, op_location); + let new_generics = lhs_generics + .into_iter() + .chain(rhs_generics) + .collect::>() + .into_iter() + .collect::>(); + + Type::GenericArith(new_id, new_generics) } } } } } + fn convert_type_to_arith_expr(&mut self, typ: Type, span: Span) -> (ArithExpr, Vec) { + match typ { + Type::Constant(value) => (ArithExpr::Constant(value), vec![]), + Type::NamedGeneric(typevar, name) => ( + ArithExpr::Variable(typevar.clone(), name.clone(), Default::default()), + vec![Type::NamedGeneric(typevar, name)], + ), + Type::GenericArith(arith_id, generics) => { + let (arith_expr, _location) = self.interner.get_arith_expression(arith_id); + (arith_expr.clone(), generics) + } + _ => { + self.push_err(ResolverError::InvalidArrayLengthExpr { span }); + (ArithExpr::Constant(0), vec![]) + } + } + } + fn get_ident_from_path(&mut self, path: Path) -> (HirIdent, usize) { let location = Location::new(path.span(), self.file); @@ -1207,6 +1275,13 @@ impl<'a> Resolver<'a> { } } } + Type::GenericArith(_, generics) => { + for generic in generics { + if let Type::NamedGeneric(type_variable, name) = generic { + found.insert(name.to_string(), type_variable.clone()); + } + } + } Type::MutableReference(element) => Self::find_numeric_generics_in_type(element, found), Type::String(length) => { if let Type::NamedGeneric(type_variable, name) = length.as_ref() { diff --git a/compiler/noirc_frontend/src/hir/resolution/traits.rs b/compiler/noirc_frontend/src/hir/resolution/traits.rs index 3d355fd4447..57f75cc4099 100644 --- a/compiler/noirc_frontend/src/hir/resolution/traits.rs +++ b/compiler/noirc_frontend/src/hir/resolution/traits.rs @@ -2,7 +2,7 @@ use std::collections::{BTreeMap, HashSet}; use fm::FileId; use iter_extended::vecmap; -use noirc_errors::Location; +use noirc_errors::{Location, Span}; use crate::ast::{ItemVisibility, Path, TraitItem}; use crate::{ @@ -13,6 +13,7 @@ use crate::{ errors::{DefCollectorErrorKind, DuplicateType}, }, def_map::{CrateDefMap, ModuleDefId, ModuleId}, + resolution::errors::ResolverError, Context, }, hir_def::traits::{TraitConstant, TraitFunction, TraitImpl, TraitType}, @@ -395,12 +396,12 @@ pub(crate) fn resolve_trait_impls( let mut methods = Vec::<(FileId, FuncId)>::new(); for trait_impl in traits { - let unresolved_type = trait_impl.object_type; + let unresolved_type = &trait_impl.object_type; let local_mod_id = trait_impl.module_id; let module_id = ModuleId { krate: crate_id, local_id: local_mod_id }; let path_resolver = StandardPathResolver::new(module_id); - let self_type_span = unresolved_type.span; + let self_type_span = unresolved_type.span.unwrap_or_else(|| trait_impl.trait_path.span()); let mut resolver = Resolver::new(interner, &path_resolver, &context.def_maps, trait_impl.file_id); @@ -410,6 +411,15 @@ pub(crate) fn resolve_trait_impls( vecmap(&trait_impl.trait_generics, |generic| resolver.resolve_type(generic.clone())); let self_type = resolver.resolve_type(unresolved_type.clone()); + + prevent_generic_arith_in_self_type(&self_type, &trait_impl, self_type_span, errors); + prevent_generic_arith_in_trait_generics( + &trait_generics, + &trait_impl, + self_type_span, + errors, + ); + let impl_generics = resolver.get_generics().to_vec(); let impl_id = interner.next_trait_impl_id(); @@ -432,8 +442,7 @@ pub(crate) fn resolve_trait_impls( } if matches!(self_type, Type::MutableReference(_)) { - let span = self_type_span.unwrap_or_else(|| trait_impl.trait_path.span()); - let error = DefCollectorErrorKind::MutableReferenceInTraitImpl { span }; + let error = DefCollectorErrorKind::MutableReferenceInTraitImpl { span: self_type_span }; errors.push((error.into(), trait_impl.file_id)); } @@ -475,7 +484,7 @@ pub(crate) fn resolve_trait_impls( ) { let error = DefCollectorErrorKind::OverlappingImpl { typ: self_type.clone(), - span: self_type_span.unwrap_or_else(|| trait_impl.trait_path.span()), + span: self_type_span, }; errors.push((error.into(), trait_impl.file_id)); @@ -491,3 +500,36 @@ pub(crate) fn resolve_trait_impls( methods } + +fn prevent_generic_arith_in_self_type( + self_type: &Type, + trait_impl: &UnresolvedTraitImpl, + self_type_span: Span, + errors: &mut Vec<(CompilationError, FileId)>, +) { + // prevent GenericArith in self_type and trait_generics + if self_type.contains_generic_arith() { + let error = ResolverError::ImplOnGenericArith { span: self_type_span }; + errors.push((error.into(), trait_impl.file_id)); + } +} + +fn prevent_generic_arith_in_trait_generics( + trait_generics: &[Type], + trait_impl: &UnresolvedTraitImpl, + self_type_span: Span, + errors: &mut Vec<(CompilationError, FileId)>, +) { + if trait_generics.iter().any(|trait_generic| trait_generic.contains_generic_arith()) { + let impl_type_span = + trait_impl.trait_generics.iter().fold(self_type_span, |acc, trait_generic| { + if let Some(trait_generic_span) = trait_generic.span { + acc.merge(trait_generic_span) + } else { + acc + } + }); + let error = ResolverError::ImplWithGenericArith { span: impl_type_span }; + errors.push((error.into(), trait_impl.file_id)); + } +} diff --git a/compiler/noirc_frontend/src/hir/type_check/expr.rs b/compiler/noirc_frontend/src/hir/type_check/expr.rs index 6504ead178f..25df1e6524e 100644 --- a/compiler/noirc_frontend/src/hir/type_check/expr.rs +++ b/compiler/noirc_frontend/src/hir/type_check/expr.rs @@ -68,17 +68,22 @@ impl<'interner> TypeChecker<'interner> { for (index, elem_type) in elem_types.iter().enumerate().skip(1) { let location = self.interner.expr_location(&arr[index]); - elem_type.unify(&first_elem_type, &mut self.errors, || { - TypeCheckError::NonHomogeneousArray { - first_span: self.interner.expr_location(&arr[0]).span, - first_type: first_elem_type.to_string(), - first_index: index, - second_span: location.span, - second_type: elem_type.to_string(), - second_index: index + 1, - } - .add_context("elements in an array must have the same type") - }); + elem_type.unify( + &first_elem_type, + &self.interner.arith_constraints, + &mut self.errors, + || { + TypeCheckError::NonHomogeneousArray { + first_span: self.interner.expr_location(&arr[0]).span, + first_type: first_elem_type.to_string(), + first_index: index, + second_span: location.span, + second_type: elem_type.to_string(), + second_index: index + 1, + } + .add_context("elements in an array must have the same type") + }, + ); } (Ok(arr.len() as u64), Box::new(first_elem_type.clone())) @@ -663,13 +668,16 @@ impl<'interner> TypeChecker<'interner> { let index_type = self.check_expression(&index_expr.index); let span = self.interner.expr_span(&index_expr.index); - index_type.unify(&self.polymorphic_integer_or_field(), &mut self.errors, || { - TypeCheckError::TypeMismatch { + index_type.unify( + &self.polymorphic_integer_or_field(), + &self.interner.arith_constraints, + &mut self.errors, + || TypeCheckError::TypeMismatch { expected_typ: "an integer".to_owned(), expr_typ: index_type.to_string(), expr_span: span, - } - }); + }, + ); // When writing `a[i]`, if `a : &mut ...` then automatically dereference `a` as many // times as needed to get the underlying array. @@ -1300,10 +1308,12 @@ impl<'interner> TypeChecker<'interner> { span: Span, ) -> Type { let mut unify = |expected| { - rhs_type.unify(&expected, &mut self.errors, || TypeCheckError::TypeMismatch { - expr_typ: rhs_type.to_string(), - expected_typ: expected.to_string(), - expr_span: span, + rhs_type.unify(&expected, &self.interner.arith_constraints, &mut self.errors, || { + TypeCheckError::TypeMismatch { + expr_typ: rhs_type.to_string(), + expected_typ: expected.to_string(), + expr_span: span, + } }); expected }; @@ -1315,10 +1325,12 @@ impl<'interner> TypeChecker<'interner> { .push(TypeCheckError::InvalidUnaryOp { kind: rhs_type.to_string(), span }); } let expected = self.polymorphic_integer_or_field(); - rhs_type.unify(&expected, &mut self.errors, || TypeCheckError::InvalidUnaryOp { - kind: rhs_type.to_string(), - span, - }); + rhs_type.unify( + &expected, + &self.interner.arith_constraints, + &mut self.errors, + || TypeCheckError::InvalidUnaryOp { kind: rhs_type.to_string(), span }, + ); expected } crate::ast::UnaryOp::Not => { diff --git a/compiler/noirc_frontend/src/hir/type_check/mod.rs b/compiler/noirc_frontend/src/hir/type_check/mod.rs index 65a3186b004..15ad70a15ab 100644 --- a/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -12,16 +12,15 @@ mod expr; mod stmt; pub use errors::TypeCheckError; -use noirc_errors::Span; use crate::{ hir_def::{ expr::HirExpression, - function::{Param, Parameters}, + function::{FuncMeta, Param}, stmt::HirStatement, traits::TraitConstraint, }, - node_interner::{ExprId, FuncId, GlobalId, NodeInterner}, + node_interner::{ArithConstraints, ExprId, FuncId, GlobalId, NodeInterner}, Type, TypeBindings, }; @@ -294,9 +293,9 @@ pub(crate) fn check_trait_impl_method_matches_declaration( &declaration_type, definition_type, method_name, - &meta.parameters, - meta.name.location.span, + meta, &trait_info.name.0.contents, + &interner.arith_constraints, &mut errors, ); } @@ -308,11 +307,14 @@ fn check_function_type_matches_expected_type( expected: &Type, actual: &Type, method_name: &str, - actual_parameters: &Parameters, - span: Span, + func_meta: &FuncMeta, trait_name: &str, + arith_constraints: &ArithConstraints, errors: &mut Vec, ) { + let actual_parameters = &func_meta.parameters; + let span = func_meta.name.location.span; + let mut bindings = TypeBindings::new(); // Shouldn't need to unify envs, they should always be equal since they're both free functions if let (Type::Function(params_a, ret_a, _env_a), Type::Function(params_b, ret_b, _env_b)) = @@ -320,7 +322,7 @@ fn check_function_type_matches_expected_type( { if params_a.len() == params_b.len() { for (i, (a, b)) in params_a.iter().zip(params_b.iter()).enumerate() { - if a.try_unify(b, &mut bindings).is_err() { + if a.try_unify(b, &mut bindings, arith_constraints).is_err() { errors.push(TypeCheckError::TraitMethodParameterTypeMismatch { method_name: method_name.to_string(), expected_typ: a.to_string(), @@ -331,7 +333,7 @@ fn check_function_type_matches_expected_type( } } - if ret_b.try_unify(ret_a, &mut bindings).is_err() { + if ret_b.try_unify(ret_a, &mut bindings, arith_constraints).is_err() { errors.push(TypeCheckError::TypeMismatch { expected_typ: ret_a.to_string(), expr_typ: ret_b.to_string(), @@ -397,7 +399,7 @@ impl<'interner> TypeChecker<'interner> { expected: &Type, make_error: impl FnOnce() -> TypeCheckError, ) { - actual.unify(expected, &mut self.errors, make_error); + actual.unify(expected, &self.interner.arith_constraints, &mut self.errors, make_error); } /// Wrapper of Type::unify_with_coercions using self.errors @@ -564,6 +566,7 @@ pub mod test { let errors = super::type_check_func(&mut interner, func_id); assert!(errors.is_empty()); + assert!(interner.arith_constraints.borrow().is_empty()); } #[test] @@ -764,7 +767,6 @@ pub mod test { // Type check section let mut errors = Vec::new(); - for function in func_ids.values() { errors.extend(super::type_check_func(&mut interner, *function)); } @@ -777,6 +779,7 @@ pub mod test { errors.len(), errors ); + assert!(interner.arith_constraints.borrow().is_empty()); (interner, main_id) } diff --git a/compiler/noirc_frontend/src/hir/type_check/stmt.rs b/compiler/noirc_frontend/src/hir/type_check/stmt.rs index 3a570922c81..6c318957b43 100644 --- a/compiler/noirc_frontend/src/hir/type_check/stmt.rs +++ b/compiler/noirc_frontend/src/hir/type_check/stmt.rs @@ -235,13 +235,16 @@ impl<'interner> TypeChecker<'interner> { let expr_span = self.interner.expr_span(index); let location = *location; - index_type.unify(&self.polymorphic_integer_or_field(), &mut self.errors, || { - TypeCheckError::TypeMismatch { + index_type.unify( + &self.polymorphic_integer_or_field(), + &self.interner.arith_constraints, + &mut self.errors, + || TypeCheckError::TypeMismatch { expected_typ: "an integer".to_owned(), expr_typ: index_type.to_string(), expr_span, - } - }); + }, + ); let (mut lvalue_type, mut lvalue, mut mutable) = self.check_lvalue(array, assign_span); diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index cf9aafbb308..9a92222d91f 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -8,7 +8,12 @@ use std::{ use crate::{ ast::IntegerBitSize, hir::type_check::TypeCheckError, - node_interner::{ExprId, NodeInterner, TraitId, TypeAliasId}, + node_interner::{ + generic_arith::{ + ArithConstraint, ArithConstraints, ArithExpr, ArithExprError, ArithId, NeedsInterning, + }, + ExprId, NodeInterner, TraitId, TypeAliasId, + }, }; use iter_extended::vecmap; use noirc_errors::{Location, Span}; @@ -100,6 +105,8 @@ pub enum Type { /// will be and thus needs the full TypeVariable link. Forall(Generics, Box), + GenericArith(ArithId, /*generics:*/ Vec), + /// A type-level integer. Included to let an Array's size type variable /// bind to an integer without special checks to bind it to a non-type. Constant(u64), @@ -116,28 +123,32 @@ pub enum Type { impl Type { /// Returns the number of field elements required to represent the type once encoded. - pub fn field_count(&self) -> u32 { + pub fn field_count(&self, location: &Location, interner: &NodeInterner) -> u32 { match self { Type::FieldElement | Type::Integer { .. } | Type::Bool => 1, Type::Array(size, typ) => { let length = size - .evaluate_to_u64() + .evaluate_to_u64(location, interner) .expect("Cannot have variable sized arrays as a parameter to main"); let typ = typ.as_ref(); - (length as u32) * typ.field_count() + (length as u32) * typ.field_count(location, interner) } Type::Struct(def, args) => { let struct_type = def.borrow(); let fields = struct_type.get_fields(args); - fields.iter().fold(0, |acc, (_, field_type)| acc + field_type.field_count()) + fields.iter().fold(0, |acc, (_, field_type)| { + acc + field_type.field_count(location, interner) + }) } - Type::Alias(def, generics) => def.borrow().get_type(generics).field_count(), - Type::Tuple(fields) => { - fields.iter().fold(0, |acc, field_typ| acc + field_typ.field_count()) + Type::Alias(def, generics) => { + def.borrow().get_type(generics).field_count(location, interner) } + Type::Tuple(fields) => fields + .iter() + .fold(0, |acc, field_typ| acc + field_typ.field_count(location, interner)), Type::String(size) => { let size = size - .evaluate_to_u64() + .evaluate_to_u64(location, interner) .expect("Cannot have variable sized strings as a parameter to main"); size as u32 } @@ -149,6 +160,7 @@ impl Type { | Type::Function(_, _, _) | Type::MutableReference(_) | Type::Forall(_, _) + | Type::GenericArith(..) | Type::Constant(_) | Type::Code | Type::Slice(_) @@ -614,7 +626,11 @@ impl Type { use TypeVariableKind as K; matches!( self.follow_bindings(), - FieldElement | Integer(..) | Bool | TypeVariable(_, K::Integer | K::IntegerOrField) + FieldElement + | Integer(..) + | Bool + | TypeVariable(_, K::Integer | K::IntegerOrField) + | GenericArith(..) ) } @@ -676,6 +692,7 @@ impl Type { generic.contains_numeric_typevar(target_id) } }), + Type::GenericArith(_, generics) => generics.iter().any(named_generic_id_matches_target), Type::MutableReference(element) => element.contains_numeric_typevar(target_id), Type::String(length) => named_generic_id_matches_target(length), Type::FmtString(length, elements) => { @@ -709,6 +726,7 @@ impl Type { | Type::Function(_, _, _) | Type::MutableReference(_) | Type::Forall(_, _) + | Type::GenericArith(..) | Type::Code | Type::Slice(_) | Type::TraitAsType(..) => false, @@ -757,6 +775,7 @@ impl Type { | Type::Slice(_) | Type::MutableReference(_) | Type::Forall(_, _) + | Type::GenericArith(..) // TODO: probably can allow code as it is all compile time | Type::Code | Type::TraitAsType(..) => false, @@ -819,6 +838,22 @@ impl Type { other => (Cow::Owned(Generics::new()), other), } } + + /// Is this type or any of its arguments (recursively) GenericArith? + pub fn contains_generic_arith(&self) -> bool { + match self { + Type::Array(_len, typ) => typ.contains_generic_arith(), + Type::Slice(typ) => typ.contains_generic_arith(), + Type::Struct(_s, args) => args.iter().any(|arg| arg.contains_generic_arith()), + Type::Alias(_alias, args) => args.iter().any(|arg| arg.contains_generic_arith()), + Type::Tuple(elements) => { + elements.iter().any(|element| element.contains_generic_arith()) + } + Type::Forall(_typevars, typ) => typ.contains_generic_arith(), + Type::GenericArith(..) => true, + _ => false, + } + } } impl std::fmt::Display for Type { @@ -908,6 +943,10 @@ impl std::fmt::Display for Type { let typevars = vecmap(typevars, |var| var.id().to_string()); write!(f, "forall {}. {}", typevars.join(" "), typ) } + Type::GenericArith(arith_id, generics) => { + let generics = vecmap(generics, ToString::to_string); + write!(f, "arith {}. {:?}", generics.join(" "), arith_id) + } Type::Function(args, ret, env) => { let closure_env_text = match **env { Type::Unit => "".to_string(), @@ -953,6 +992,7 @@ impl std::fmt::Display for TypeBinding { } } +#[derive(Debug)] pub struct UnificationError; impl Type { @@ -1011,6 +1051,7 @@ impl Type { }, } } + Type::GenericArith(..) => Ok(()), _ => Err(UnificationError), } } @@ -1095,6 +1136,10 @@ impl Type { } } } + Type::GenericArith(..) => { + // all done in try_unify with copy/paste block + Ok(()) + } _ => Err(UnificationError), } } @@ -1136,13 +1181,17 @@ impl Type { } } - fn get_inner_type_variable(&self) -> Option> { + pub(crate) fn get_outer_type_variable(&self) -> Option { match self { - Type::TypeVariable(var, _) | Type::NamedGeneric(var, _) => Some(var.1.clone()), + Type::TypeVariable(var, _) | Type::NamedGeneric(var, _) => Some(var.clone()), _ => None, } } + fn get_inner_type_variable(&self) -> Option> { + self.get_outer_type_variable().map(|var| var.1) + } + /// Try to unify this type with another, setting any type variables found /// equal to the other type in the process. When comparing types, unification /// (including try_unify) are almost always preferred over Type::eq as unification @@ -1150,12 +1199,13 @@ impl Type { pub fn unify( &self, expected: &Type, + arith_constraints: &ArithConstraints, errors: &mut Vec, make_error: impl FnOnce() -> TypeCheckError, ) { let mut bindings = TypeBindings::new(); - match self.try_unify(expected, &mut bindings) { + match self.try_unify(expected, &mut bindings, arith_constraints) { Ok(()) => { // Commit any type bindings on success Self::apply_type_bindings(bindings); @@ -1170,6 +1220,7 @@ impl Type { &self, other: &Type, bindings: &mut TypeBindings, + arith_constraints: &ArithConstraints, ) -> Result<(), UnificationError> { use Type::*; use TypeVariableKind as Kind; @@ -1179,12 +1230,13 @@ impl Type { (Alias(alias, args), other) | (other, Alias(alias, args)) => { let alias = alias.borrow().get_type(args); - alias.try_unify(other, bindings) + alias.try_unify(other, bindings, arith_constraints) } (TypeVariable(var, Kind::IntegerOrField), other) | (other, TypeVariable(var, Kind::IntegerOrField)) => { - other.try_unify_to_type_variable(var, bindings, |bindings| { + other.try_unify_to_type_variable(var, bindings, arith_constraints, |bindings| { + other.try_unify_arith_generic_to_type_variable(var, arith_constraints); let only_integer = false; other.try_bind_to_polymorphic_int(var, bindings, only_integer) }) @@ -1192,36 +1244,39 @@ impl Type { (TypeVariable(var, Kind::Integer), other) | (other, TypeVariable(var, Kind::Integer)) => { - other.try_unify_to_type_variable(var, bindings, |bindings| { + other.try_unify_to_type_variable(var, bindings, arith_constraints, |bindings| { + other.try_unify_arith_generic_to_type_variable(var, arith_constraints); let only_integer = true; other.try_bind_to_polymorphic_int(var, bindings, only_integer) }) } (TypeVariable(var, Kind::Normal), other) | (other, TypeVariable(var, Kind::Normal)) => { - other.try_unify_to_type_variable(var, bindings, |bindings| { + other.try_unify_to_type_variable(var, bindings, arith_constraints, |bindings| { + other.try_unify_arith_generic_to_type_variable(var, arith_constraints); other.try_bind_to(var, bindings) }) } (TypeVariable(var, Kind::Constant(length)), other) | (other, TypeVariable(var, Kind::Constant(length))) => other - .try_unify_to_type_variable(var, bindings, |bindings| { + .try_unify_to_type_variable(var, bindings, arith_constraints, |bindings| { + other.try_unify_arith_generic_to_type_variable(var, arith_constraints); other.try_bind_to_maybe_constant(var, *length, bindings) }), (Array(len_a, elem_a), Array(len_b, elem_b)) => { - len_a.try_unify(len_b, bindings)?; - elem_a.try_unify(elem_b, bindings) + len_a.try_unify(len_b, bindings, arith_constraints)?; + elem_a.try_unify(elem_b, bindings, arith_constraints) } - (Slice(elem_a), Slice(elem_b)) => elem_a.try_unify(elem_b, bindings), + (Slice(elem_a), Slice(elem_b)) => elem_a.try_unify(elem_b, bindings, arith_constraints), - (String(len_a), String(len_b)) => len_a.try_unify(len_b, bindings), + (String(len_a), String(len_b)) => len_a.try_unify(len_b, bindings, arith_constraints), (FmtString(len_a, elements_a), FmtString(len_b, elements_b)) => { - len_a.try_unify(len_b, bindings)?; - elements_a.try_unify(elements_b, bindings) + len_a.try_unify(len_b, bindings, arith_constraints)?; + elements_a.try_unify(elements_b, bindings, arith_constraints) } (Tuple(elements_a), Tuple(elements_b)) => { @@ -1229,7 +1284,7 @@ impl Type { Err(UnificationError) } else { for (a, b) in elements_a.iter().zip(elements_b) { - a.try_unify(b, bindings)?; + a.try_unify(b, bindings, arith_constraints)?; } Ok(()) } @@ -1241,7 +1296,7 @@ impl Type { (Struct(id_a, args_a), Struct(id_b, args_b)) => { if id_a == id_b && args_a.len() == args_b.len() { for (a, b) in args_a.iter().zip(args_b) { - a.try_unify(b, bindings)?; + a.try_unify(b, bindings, arith_constraints)?; } Ok(()) } else { @@ -1253,7 +1308,7 @@ impl Type { if !binding.borrow().is_unbound() => { if let TypeBinding::Bound(link) = &*binding.borrow() { - link.try_unify(other, bindings) + link.try_unify(other, bindings, arith_constraints) } else { unreachable!("If guard ensures binding is bound") } @@ -1274,18 +1329,66 @@ impl Type { (Function(params_a, ret_a, env_a), Function(params_b, ret_b, env_b)) => { if params_a.len() == params_b.len() { for (a, b) in params_a.iter().zip(params_b.iter()) { - a.try_unify(b, bindings)?; + a.try_unify(b, bindings, arith_constraints)?; } - env_a.try_unify(env_b, bindings)?; - ret_b.try_unify(ret_a, bindings) + env_a.try_unify(env_b, bindings, arith_constraints)?; + ret_b.try_unify(ret_a, bindings, arith_constraints) } else { Err(UnificationError) } } (MutableReference(elem_a), MutableReference(elem_b)) => { - elem_a.try_unify(elem_b, bindings) + elem_a.try_unify(elem_b, bindings, arith_constraints) + } + + (GenericArith(lhs, lhs_generics), GenericArith(rhs, rhs_generics)) => { + if lhs == rhs && lhs_generics == rhs_generics { + return Ok(()); + } + + arith_constraints.borrow_mut().push(ArithConstraint { + lhs: *lhs, + lhs_generics: lhs_generics.to_vec(), + rhs: *rhs, + rhs_generics: rhs_generics.to_vec(), + needs_interning: NeedsInterning::Neither, + }); + + Ok(()) + } + + (GenericArith(lhs, lhs_generics), Constant(rhs)) => { + let rhs_expr = ArithExpr::Constant(*rhs); + let rhs = rhs_expr.to_id(); + let rhs_generics = vec![]; + + arith_constraints.borrow_mut().push(ArithConstraint { + lhs: *lhs, + lhs_generics: lhs_generics.to_vec(), + rhs, + rhs_generics, + needs_interning: NeedsInterning::Rhs(rhs_expr), + }); + + Ok(()) + } + + (Constant(lhs), GenericArith(rhs, rhs_generics)) => { + let lhs_expr = ArithExpr::Constant(*lhs); + let lhs = lhs_expr.to_id(); + let lhs_generics = vec![]; + + arith_constraints.borrow_mut().push(ArithConstraint { + lhs, + lhs_generics, + rhs: *rhs, + rhs_generics: rhs_generics.to_vec(), + needs_interning: NeedsInterning::Lhs(lhs_expr), + }); + + Ok(()) } (other_a, other_b) => { @@ -1298,12 +1401,34 @@ impl Type { } } + fn try_unify_arith_generic_to_type_variable( + &self, + var: &TypeVariable, + arith_constraints: &ArithConstraints, + ) { + if let Self::GenericArith(lhs, lhs_generics) = self { + let name: Rc = format!("#implicit_var({:?})", var).into(); + let rhs_expr = ArithExpr::Variable(var.clone(), name, Default::default()); + let rhs = rhs_expr.to_id(); + let rhs_generics = vec![]; + + arith_constraints.borrow_mut().push(ArithConstraint { + lhs: *lhs, + lhs_generics: lhs_generics.to_vec(), + rhs, + rhs_generics, + needs_interning: NeedsInterning::Rhs(rhs_expr), + }); + } + } + /// Try to unify a type variable to `self`. /// This is a helper function factored out from try_unify. fn try_unify_to_type_variable( &self, type_variable: &TypeVariable, bindings: &mut TypeBindings, + arith_constraints: &ArithConstraints, // Bind the type variable to a type. This is factored out since depending on the // TypeVariableKind, there are different methods to check whether the variable can @@ -1312,12 +1437,14 @@ impl Type { ) -> Result<(), UnificationError> { match &*type_variable.borrow() { // If it is already bound, unify against what it is bound to - TypeBinding::Bound(link) => link.try_unify(self, bindings), + TypeBinding::Bound(link) => link.try_unify(self, bindings, arith_constraints), TypeBinding::Unbound(id) => { // We may have already "bound" this type variable in this call to // try_unify, so check those bindings as well. match bindings.get(id) { - Some((_, binding)) => binding.clone().try_unify(self, bindings), + Some((_, binding)) => { + binding.clone().try_unify(self, bindings, arith_constraints) + } // Otherwise, bind it None => bind_variable(bindings), @@ -1341,7 +1468,9 @@ impl Type { ) { let mut bindings = TypeBindings::new(); - if let Err(UnificationError) = self.try_unify(expected, &mut bindings) { + if let Err(UnificationError) = + self.try_unify(expected, &mut bindings, &interner.arith_constraints) + { if !self.try_array_to_slice_coercion(expected, expression, interner) { errors.push(make_error()); } @@ -1365,7 +1494,7 @@ impl Type { // Still have to ensure the element types match. // Don't need to issue an error here if not, it will be done in unify_with_coercions let mut bindings = TypeBindings::new(); - if element1.try_unify(element2, &mut bindings).is_ok() { + if element1.try_unify(element2, &mut bindings, &interner.arith_constraints).is_ok() { convert_array_expression_to_slice(expression, this, target, interner); Self::apply_type_bindings(bindings); return true; @@ -1384,18 +1513,32 @@ impl Type { /// If this type is a Type::Constant (used in array lengths), or is bound /// to a Type::Constant, return the constant as a u64. - pub fn evaluate_to_u64(&self) -> Option { + // NOTE: only_used_in_recursion false positive + #[allow(clippy::only_used_in_recursion)] + pub fn evaluate_to_u64( + &self, + location: &Location, + interner: &NodeInterner, + ) -> Result { if let Some(binding) = self.get_inner_type_variable() { if let TypeBinding::Bound(binding) = &*binding.borrow() { - return binding.evaluate_to_u64(); + return binding.evaluate_to_u64(location, interner); } } match self { - Type::TypeVariable(_, TypeVariableKind::Constant(size)) => Some(*size), - Type::Array(len, _elem) => len.evaluate_to_u64(), - Type::Constant(x) => Some(*x), - _ => None, + Type::TypeVariable(_, TypeVariableKind::Constant(size)) => Ok(*size), + Type::Array(len, _elem) => len.evaluate_to_u64(location, interner), + Type::Constant(x) => Ok(*x), + Type::GenericArith(arith_id, generics) => { + let (expr, location) = interner.get_arith_expression(*arith_id); + let generics = + ArithConstraint::evaluate_generics_to_u64(generics, location, interner)?; + expr.evaluate(interner, &generics) + } + unexpected_type => Err(ArithExprError::EvaluateUnexpectedType { + unexpected_type: unexpected_type.clone(), + }), } } @@ -1618,6 +1761,12 @@ impl Type { let typ = Box::new(typ.substitute_helper(type_bindings, substitute_bound_typevars)); Type::Forall(typevars.clone(), typ) } + Type::GenericArith(arith_id, generics) => { + let generics = vecmap(generics, |generic| { + generic.substitute_helper(type_bindings, substitute_bound_typevars) + }); + Type::GenericArith(*arith_id, generics) + } Type::Function(args, ret, env) => { let args = vecmap(args, |arg| { arg.substitute_helper(type_bindings, substitute_bound_typevars) @@ -1673,6 +1822,9 @@ impl Type { Type::Forall(typevars, typ) => { !typevars.iter().any(|var| var.id() == target_id) && typ.occurs(target_id) } + Type::GenericArith(_, generics) => { + generics.iter().any(|generic| generic.occurs(target_id)) + } Type::Function(args, ret, env) => { args.iter().any(|arg| arg.occurs(target_id)) || ret.occurs(target_id) @@ -1733,6 +1885,9 @@ impl Type { Function(args, ret, env) } + GenericArith(arith_id, generics) => { + GenericArith(*arith_id, vecmap(generics, |generic| generic.follow_bindings())) + } MutableReference(element) => MutableReference(Box::new(element.follow_bindings())), TraitAsType(s, name, args) => { @@ -1824,7 +1979,11 @@ impl From<&Type> for PrintableType { match value { Type::FieldElement => PrintableType::Field, Type::Array(size, typ) => { - let length = size.evaluate_to_u64().expect("Cannot print variable sized arrays"); + let location = Location::dummy(); + let interner = NodeInterner::default(); + let length = size + .evaluate_to_u64(&location, &interner) + .expect("Cannot print variable sized arrays"); let typ = typ.as_ref(); PrintableType::Array { length, typ: Box::new(typ.into()) } } @@ -1850,7 +2009,11 @@ impl From<&Type> for PrintableType { } Type::Bool => PrintableType::Boolean, Type::String(size) => { - let size = size.evaluate_to_u64().expect("Cannot print variable sized strings"); + let location = Location::dummy(); + let interner = NodeInterner::default(); + let size = size + .evaluate_to_u64(&location, &interner) + .expect("Cannot print variable sized strings"); PrintableType::String { length: size } } Type::FmtString(_, _) => unreachable!("format strings cannot be printed"), @@ -1869,6 +2032,7 @@ impl From<&Type> for PrintableType { Type::TypeVariable(_, _) => unreachable!(), Type::NamedGeneric(..) => unreachable!(), Type::Forall(..) => unreachable!(), + Type::GenericArith(..) => unreachable!(), Type::Function(arguments, return_type, env) => PrintableType::Function { arguments: arguments.iter().map(|arg| arg.into()).collect(), return_type: Box::new(return_type.as_ref().into()), @@ -1949,6 +2113,10 @@ impl std::fmt::Debug for Type { let typevars = vecmap(typevars, |var| format!("{:?}", var)); write!(f, "forall {}. {:?}", typevars.join(" "), typ) } + Type::GenericArith(arith_id, generics) => { + let generics = vecmap(generics, |generic| format!("{:?}", generic)); + write!(f, "arith {}. {:?}", generics.join(" "), arith_id) + } Type::Function(args, ret, env) => { let closure_env_text = match **env { Type::Unit => "".to_string(), diff --git a/compiler/noirc_frontend/src/monomorphization/errors.rs b/compiler/noirc_frontend/src/monomorphization/errors.rs index 3011c26cffe..1405f074e65 100644 --- a/compiler/noirc_frontend/src/monomorphization/errors.rs +++ b/compiler/noirc_frontend/src/monomorphization/errors.rs @@ -2,6 +2,8 @@ use thiserror::Error; use noirc_errors::{CustomDiagnostic, FileDiagnostic, Location}; +use crate::node_interner::generic_arith::{ArithConstraintError, ArithExprError}; + #[derive(Debug, Error)] pub enum MonomorphizationError { #[error("Length of generic array could not be determined.")] @@ -9,13 +11,31 @@ pub enum MonomorphizationError { #[error("Type annotations needed")] TypeAnnotationsNeeded { location: Location }, + + #[error("Failed to prove generic arithmetic equivalent:\n{error}")] + ArithConstraintError { error: ArithConstraintError }, + + #[error("Failed to prove generic arithmetic equivalent:\n{arith_expr_error}")] + ArithExprError { arith_expr_error: ArithExprError, location: Location }, } impl MonomorphizationError { fn location(&self) -> Location { match self { MonomorphizationError::UnknownArrayLength { location } - | MonomorphizationError::TypeAnnotationsNeeded { location } => *location, + | MonomorphizationError::TypeAnnotationsNeeded { location } + | MonomorphizationError::ArithExprError { location, .. } => *location, + + MonomorphizationError::ArithConstraintError { error } => error.location(), + } + } + + fn other_locations(&self) -> Vec { + match self { + MonomorphizationError::UnknownArrayLength { .. } + | MonomorphizationError::TypeAnnotationsNeeded { .. } + | MonomorphizationError::ArithExprError { .. } => vec![], + MonomorphizationError::ArithConstraintError { error } => error.other_locations(), } } } @@ -23,12 +43,21 @@ impl MonomorphizationError { impl From for FileDiagnostic { fn from(error: MonomorphizationError) -> FileDiagnostic { let location = error.location(); - let call_stack = vec![location]; + let call_stack: Vec<_> = std::iter::once(location) + .chain(error.other_locations()) + .filter(|x| x != &Location::dummy()) + .collect(); let diagnostic = error.into_diagnostic(); diagnostic.in_file(location.file).with_call_stack(call_stack) } } +impl From for MonomorphizationError { + fn from(error: ArithConstraintError) -> Self { + Self::ArithConstraintError { error } + } +} + impl MonomorphizationError { fn into_diagnostic(self) -> CustomDiagnostic { let message = self.to_string(); diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 2e74eb87e60..1cf8b1b954d 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -144,7 +144,7 @@ pub fn monomorphize_debug( }) .collect(); - let functions = vecmap(monomorphizer.finished_functions, |(_, f)| f); + let functions = vecmap(monomorphizer.finished_functions.clone(), |(_, f)| f); let FuncMeta { return_visibility, kind, .. } = monomorphizer.interner.function_meta(&main); let (debug_variables, debug_functions, debug_types) = @@ -160,6 +160,10 @@ pub fn monomorphize_debug( debug_functions, debug_types, ); + + // we need to check the arith constraints + monomorphizer.validate_arith_constraints()?; + assert!(monomorphizer.interner.arith_constraints.borrow().is_empty()); Ok(program) } @@ -315,7 +319,7 @@ impl<'interner> Monomorphizer<'interner> { other => other, }; - let return_type = Self::convert_type(return_type, meta.location)?; + let return_type = Self::convert_type(return_type, meta.location, self.interner)?; let unconstrained = modifiers.is_unconstrained; let attributes = self.interner.function_attributes(&f); @@ -367,7 +371,7 @@ impl<'interner> Monomorphizer<'interner> { let new_id = self.next_local_id(); let definition = self.interner.definition(ident.id); let name = definition.name.clone(); - let typ = Self::convert_type(typ, ident.location)?; + let typ = Self::convert_type(typ, ident.location, self.interner)?; new_params.push((new_id, definition.mutable, name, typ)); self.define_local(ident.id, new_id); } @@ -421,7 +425,8 @@ impl<'interner> Monomorphizer<'interner> { HirExpression::Literal(HirLiteral::Bool(value)) => Literal(Bool(value)), HirExpression::Literal(HirLiteral::Integer(value, sign)) => { let location = self.interner.id_location(expr); - let typ = Self::convert_type(&self.interner.id_type(expr), location)?; + let typ = + Self::convert_type(&self.interner.id_type(expr), location, self.interner)?; if sign { match typ { @@ -457,7 +462,11 @@ impl<'interner> Monomorphizer<'interner> { ast::Expression::Unary(ast::Unary { operator: prefix.operator, rhs: Box::new(self.expr(prefix.rhs)?), - result_type: Self::convert_type(&self.interner.id_type(expr), location)?, + result_type: Self::convert_type( + &self.interner.id_type(expr), + location, + self.interner, + )?, location, }) } @@ -507,7 +516,7 @@ impl<'interner> Monomorphizer<'interner> { HirExpression::Cast(cast) => { let location = self.interner.expr_location(&expr); - let typ = Self::convert_type(&cast.r#type, location)?; + let typ = Self::convert_type(&cast.r#type, location, self.interner)?; let lhs = Box::new(self.expr(cast.lhs)?); ast::Expression::Cast(ast::Cast { lhs, r#type: typ, location }) } @@ -519,7 +528,8 @@ impl<'interner> Monomorphizer<'interner> { if_expr.alternative.map(|alt| self.expr(alt)).transpose()?.map(Box::new); let location = self.interner.expr_location(&expr); - let typ = Self::convert_type(&self.interner.id_type(expr), location)?; + let typ = + Self::convert_type(&self.interner.id_type(expr), location, self.interner)?; ast::Expression::If(ast::If { condition, consequence, alternative: else_, typ }) } @@ -554,7 +564,7 @@ impl<'interner> Monomorphizer<'interner> { is_slice: bool, ) -> Result { let location = self.interner.expr_location(&array); - let typ = Self::convert_type(&self.interner.id_type(array), location)?; + let typ = Self::convert_type(&self.interner.id_type(array), location, self.interner)?; let contents = try_vecmap(array_elements, |id| self.expr(id))?; if is_slice { Ok(ast::Expression::Literal(ast::Literal::Slice(ast::ArrayLiteral { contents, typ }))) @@ -571,12 +581,12 @@ impl<'interner> Monomorphizer<'interner> { is_slice: bool, ) -> Result { let location = self.interner.expr_location(&array); - let typ = Self::convert_type(&self.interner.id_type(array), location)?; + let typ = Self::convert_type(&self.interner.id_type(array), location, self.interner)?; - let length = length.evaluate_to_u64().ok_or_else(|| { - let location = self.interner.expr_location(&array); - MonomorphizationError::UnknownArrayLength { location } - })?; + let length = + length.evaluate_to_u64(&location, self.interner).map_err(|arith_expr_error| { + MonomorphizationError::ArithExprError { arith_expr_error, location } + })?; let contents = try_vecmap(0..length, |_| self.expr(repeated_element))?; if is_slice { @@ -592,7 +602,7 @@ impl<'interner> Monomorphizer<'interner> { index: HirIndexExpression, ) -> Result { let location = self.interner.expr_location(&id); - let element_type = Self::convert_type(&self.interner.id_type(id), location)?; + let element_type = Self::convert_type(&self.interner.id_type(id), location, self.interner)?; let collection = Box::new(self.expr(index.collection)?); let index = Box::new(self.expr(index.index)?); @@ -629,7 +639,7 @@ impl<'interner> Monomorphizer<'interner> { let block = Box::new(self.expr(for_loop.block)?); let index_location = for_loop.identifier.location; let index_type = self.interner.id_type(for_loop.start_range); - let index_type = Self::convert_type(&index_type, index_location)?; + let index_type = Self::convert_type(&index_type, index_location, self.interner)?; Ok(ast::Expression::For(ast::For { index_variable, @@ -683,7 +693,7 @@ impl<'interner> Monomorphizer<'interner> { let new_id = self.next_local_id(); let field_type = field_type_map.get(&field_name.0.contents).unwrap(); let location = self.interner.expr_location(&expr_id); - let typ = Self::convert_type(field_type, location)?; + let typ = Self::convert_type(field_type, location, self.interner)?; field_vars.insert(field_name.0.contents.clone(), (new_id, typ)); let expression = Box::new(self.expr(expr_id)?); @@ -783,7 +793,7 @@ impl<'interner> Monomorphizer<'interner> { let mutable = false; let definition = Definition::Local(fresh_id); let name = i.to_string(); - let typ = Self::convert_type(&field_type, location)?; + let typ = Self::convert_type(&field_type, location, self.interner)?; let location = Some(location); let new_rhs = @@ -832,7 +842,11 @@ impl<'interner> Monomorphizer<'interner> { return Ok(None); }; - let typ = Self::convert_type(&self.interner.definition_type(ident.id), ident.location)?; + let typ = Self::convert_type( + &self.interner.definition_type(ident.id), + ident.location, + self.interner, + )?; Ok(Some(ast::Ident { location: Some(ident.location), mutable, definition, name, typ })) } @@ -861,7 +875,7 @@ impl<'interner> Monomorphizer<'interner> { generics.unwrap_or_default(), None, ); - let typ = Self::convert_type(&typ, ident.location)?; + let typ = Self::convert_type(&typ, ident.location, self.interner)?; let ident = ast::Ident { location, mutable, definition, name, typ: typ.clone() }; let ident_expression = ast::Expression::Ident(ident); if self.is_function_closure_type(&typ) { @@ -896,14 +910,17 @@ impl<'interner> Monomorphizer<'interner> { TypeBinding::Unbound(_) => { unreachable!("Unbound type variable used in expression") } - TypeBinding::Bound(binding) => binding.evaluate_to_u64().unwrap_or_else(|| { - panic!("Non-numeric type variable used in expression expecting a value") - }), + TypeBinding::Bound(binding) => { + let location = self.interner.id_location(expr_id); + binding.evaluate_to_u64(&location, self.interner).unwrap_or_else(|err| { + panic!("Non-numeric type variable used in expression expecting a value: {:?}", err) + }) + } }; let value = FieldElement::from(value as u128); let location = self.interner.id_location(expr_id); - let typ = Self::convert_type(&typ, ident.location)?; + let typ = Self::convert_type(&typ, ident.location, self.interner)?; ast::Expression::Literal(ast::Literal::Integer(value, typ, location)) } }; @@ -912,28 +929,39 @@ impl<'interner> Monomorphizer<'interner> { } /// Convert a non-tuple/struct type to a monomorphized type - fn convert_type(typ: &HirType, location: Location) -> Result { + fn convert_type( + typ: &HirType, + location: Location, + interner: &NodeInterner, + ) -> Result { Ok(match typ { HirType::FieldElement => ast::Type::Field, HirType::Integer(sign, bits) => ast::Type::Integer(*sign, *bits), HirType::Bool => ast::Type::Bool, - HirType::String(size) => ast::Type::String(size.evaluate_to_u64().unwrap_or(0)), + HirType::String(size) => { + ast::Type::String(size.evaluate_to_u64(&location, interner).unwrap_or(0)) + } HirType::FmtString(size, fields) => { - let size = size.evaluate_to_u64().unwrap_or(0); - let fields = Box::new(Self::convert_type(fields.as_ref(), location)?); + let size = size.evaluate_to_u64(&location, interner).unwrap_or(0); + let fields = Box::new(Self::convert_type(fields.as_ref(), location, interner)?); ast::Type::FmtString(size, fields) } HirType::Unit => ast::Type::Unit, HirType::Array(length, element) => { - let element = Box::new(Self::convert_type(element.as_ref(), location)?); - let length = match length.evaluate_to_u64() { - Some(length) => length, - None => return Err(MonomorphizationError::TypeAnnotationsNeeded { location }), + let element = Box::new(Self::convert_type(element.as_ref(), location, interner)?); + let length = match length.evaluate_to_u64(&location, interner) { + Ok(length) => length, + Err(arith_expr_error) => { + return Err(MonomorphizationError::ArithExprError { + arith_expr_error, + location, + }); + } }; ast::Type::Array(length, element) } HirType::Slice(element) => { - let element = Box::new(Self::convert_type(element.as_ref(), location)?); + let element = Box::new(Self::convert_type(element.as_ref(), location, interner)?); ast::Type::Slice(element) } HirType::TraitAsType(..) => { @@ -941,7 +969,7 @@ impl<'interner> Monomorphizer<'interner> { } HirType::NamedGeneric(binding, _) => { if let TypeBinding::Bound(binding) = &*binding.borrow() { - return Self::convert_type(binding, location); + return Self::convert_type(binding, location, interner); } // Default any remaining unbound type variables. @@ -953,7 +981,7 @@ impl<'interner> Monomorphizer<'interner> { HirType::TypeVariable(binding, kind) => { if let TypeBinding::Bound(binding) = &*binding.borrow() { - return Self::convert_type(binding, location); + return Self::convert_type(binding, location, interner); } // Default any remaining unbound type variables. @@ -964,30 +992,32 @@ impl<'interner> Monomorphizer<'interner> { None => return Err(MonomorphizationError::TypeAnnotationsNeeded { location }), }; - let monomorphized_default = Self::convert_type(&default, location)?; + let monomorphized_default = Self::convert_type(&default, location, interner)?; binding.bind(default); monomorphized_default } HirType::Struct(def, args) => { let fields = def.borrow().get_fields(args); - let fields = try_vecmap(fields, |(_, field)| Self::convert_type(&field, location))?; + let fields = try_vecmap(fields, |(_, field)| { + Self::convert_type(&field, location, interner) + })?; ast::Type::Tuple(fields) } HirType::Alias(def, args) => { - Self::convert_type(&def.borrow().get_type(args), location)? + Self::convert_type(&def.borrow().get_type(args), location, interner)? } HirType::Tuple(fields) => { - let fields = try_vecmap(fields, |x| Self::convert_type(x, location))?; + let fields = try_vecmap(fields, |x| Self::convert_type(x, location, interner))?; ast::Type::Tuple(fields) } HirType::Function(args, ret, env) => { - let args = try_vecmap(args, |x| Self::convert_type(x, location))?; - let ret = Box::new(Self::convert_type(ret, location)?); - let env = Self::convert_type(env, location)?; + let args = try_vecmap(args, |x| Self::convert_type(x, location, interner))?; + let ret = Box::new(Self::convert_type(ret, location, interner)?); + let env = Self::convert_type(env, location, interner)?; match &env { ast::Type::Unit => ast::Type::Function(args, ret, Box::new(env)), ast::Type::Tuple(_elements) => ast::Type::Tuple(vec![ @@ -1002,8 +1032,22 @@ impl<'interner> Monomorphizer<'interner> { } } + HirType::GenericArith(_arith_id, generics) => { + try_vecmap(generics, |generic| { + match generic.evaluate_to_u64(&location, interner) { + Ok(result) => Ok(result), + Err(arith_expr_error) => Err(MonomorphizationError::ArithExprError { + arith_expr_error, + location, + }), + } + })?; + + ast::Type::Field + } + HirType::MutableReference(element) => { - let element = Self::convert_type(element, location)?; + let element = Self::convert_type(element, location, interner)?; ast::Type::MutableReference(Box::new(element)) } @@ -1099,7 +1143,7 @@ impl<'interner> Monomorphizer<'interner> { mutable: false, location: None, name: the_trait.methods[method.method_index].name.0.contents.clone(), - typ: Self::convert_type(&function_type, location)?, + typ: Self::convert_type(&function_type, location, self.interner)?, })) } @@ -1116,7 +1160,7 @@ impl<'interner> Monomorphizer<'interner> { let return_type = self.interner.id_type(id); let location = self.interner.expr_location(&id); - let return_type = Self::convert_type(&return_type, location)?; + let return_type = Self::convert_type(&return_type, location, self.interner)?; let location = call.location; @@ -1133,7 +1177,7 @@ impl<'interner> Monomorphizer<'interner> { let mut block_expressions = vec![]; let func_type = self.interner.id_type(call.func); - let func_type = Self::convert_type(&func_type, location)?; + let func_type = Self::convert_type(&func_type, location, self.interner)?; let is_closure = self.is_function_closure(func_type); let func = if is_closure { @@ -1155,7 +1199,11 @@ impl<'interner> Monomorphizer<'interner> { definition: Definition::Local(local_id), mutable: false, name: "tmp".to_string(), - typ: Self::convert_type(&self.interner.id_type(call.func), location)?, + typ: Self::convert_type( + &self.interner.id_type(call.func), + location, + self.interner, + )?, }); let env_argument = @@ -1367,12 +1415,12 @@ impl<'interner> Monomorphizer<'interner> { HirLValue::Index { array, index, typ, location } => { let array = Box::new(self.lvalue(*array)?); let index = Box::new(self.expr(index)?); - let element_type = Self::convert_type(&typ, location)?; + let element_type = Self::convert_type(&typ, location, self.interner)?; ast::LValue::Index { array, index, element_type, location } } HirLValue::Dereference { lvalue, element_type, location } => { let reference = Box::new(self.lvalue(*lvalue)?); - let element_type = Self::convert_type(&element_type, location)?; + let element_type = Self::convert_type(&element_type, location, self.interner)?; ast::LValue::Dereference { reference, element_type } } }; @@ -1399,10 +1447,11 @@ impl<'interner> Monomorphizer<'interner> { expr: node_interner::ExprId, ) -> Result { let location = self.interner.expr_location(&expr); - let ret_type = Self::convert_type(&lambda.return_type, location)?; + let ret_type = Self::convert_type(&lambda.return_type, location, self.interner)?; let lambda_name = "lambda"; - let parameter_types = - try_vecmap(&lambda.parameters, |(_, typ)| Self::convert_type(typ, location))?; + let parameter_types = try_vecmap(&lambda.parameters, |(_, typ)| { + Self::convert_type(typ, location, self.interner) + })?; // Manually convert to Parameters type so we can reuse the self.parameters method let parameters = @@ -1461,10 +1510,11 @@ impl<'interner> Monomorphizer<'interner> { // which seems more fragile, we directly reuse the return parameters // of this function in those cases let location = self.interner.expr_location(&expr); - let ret_type = Self::convert_type(&lambda.return_type, location)?; + let ret_type = Self::convert_type(&lambda.return_type, location, self.interner)?; let lambda_name = "lambda"; - let parameter_types = - try_vecmap(&lambda.parameters, |(_, typ)| Self::convert_type(typ, location))?; + let parameter_types = try_vecmap(&lambda.parameters, |(_, typ)| { + Self::convert_type(typ, location, self.interner) + })?; // Manually convert to Parameters type so we can reuse the self.parameters method let parameters = @@ -1487,7 +1537,10 @@ impl<'interner> Monomorphizer<'interner> { ); let ident = Box::new(ast::Expression::Ident(lambda_ctx.env_ident.clone())); - Ok(ast::Expression::ExtractTupleField(ident, field_index)) + Ok::<_, MonomorphizationError>(ast::Expression::ExtractTupleField( + ident, + field_index, + )) } None => { let ident = self.local_ident(&capture.ident)?.unwrap(); @@ -1498,7 +1551,7 @@ impl<'interner> Monomorphizer<'interner> { let expr_type = self.interner.id_type(expr); let env_typ = if let types::Type::Function(_, _, function_env_type) = expr_type { - Self::convert_type(&function_env_type, location)? + Self::convert_type(&function_env_type, location, self.interner)? } else { unreachable!("expected a Function type for a Lambda node") }; @@ -1708,7 +1761,7 @@ impl<'interner> Monomorphizer<'interner> { ) -> Result { let arguments = vec![lhs, rhs]; let func = Box::new(func); - let return_type = Self::convert_type(&ret, location)?; + let return_type = Self::convert_type(&ret, location, self.interner)?; let mut result = ast::Expression::Call(ast::Call { func, arguments, return_type, location }); @@ -1762,7 +1815,7 @@ impl<'interner> Monomorphizer<'interner> { /// static method references to generic impls (e.g. `Eq::eq` for `[T; N]`) will fail to re-apply /// the correct type bindings during monomorphization. fn perform_impl_bindings( - &self, + &mut self, trait_method: Option, impl_method: node_interner::FuncId, ) -> TypeBindings { @@ -1786,7 +1839,7 @@ impl<'interner> Monomorphizer<'interner> { let type_bindings = generics.iter().map(replace_type_variable).collect(); let impl_method_type = impl_method_type.force_substitute(&type_bindings); - trait_method_type.try_unify(&impl_method_type, &mut bindings).unwrap_or_else(|_| { + trait_method_type.try_unify(&impl_method_type, &mut bindings, &self.interner.arith_constraints).unwrap_or_else(|_| { unreachable!("Impl method type {} does not unify with trait method type {} during monomorphization", impl_method_type, trait_method_type) }); @@ -1795,6 +1848,15 @@ impl<'interner> Monomorphizer<'interner> { bindings } + + /// Validate and consume all of the `ArithConstraints`` + fn validate_arith_constraints(&mut self) -> Result<(), MonomorphizationError> { + let arith_constraints = self.interner.arith_constraints.replace(vec![]); + for arith_constraint in arith_constraints { + arith_constraint.validate(self.interner)?; + } + Ok(()) + } } fn unwrap_tuple_type(typ: &HirType) -> Vec { diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index e28a0a64ad0..034c375e878 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -1,6 +1,7 @@ use std::borrow::Cow; use std::collections::HashMap; use std::fmt; +use std::hash::Hash; use std::ops::Deref; use fm::FileId; @@ -33,6 +34,10 @@ use crate::{ Generics, Shared, TypeAlias, TypeBindings, TypeVariable, TypeVariableId, TypeVariableKind, }; +pub mod generic_arith; + +pub use generic_arith::{ArithConstraints, ArithExpr, ArithId}; + /// An arbitrary number to limit the recursion depth when searching for trait impls. /// This is needed to stop recursing for cases such as `impl Foo for T where T: Eq` const IMPL_SEARCH_RECURSION_LIMIT: u32 = 10; @@ -151,6 +156,10 @@ pub struct NodeInterner { globals: Vec, global_attributes: HashMap>, + arith_expressions: HashMap, + + pub(crate) arith_constraints: ArithConstraints, + next_type_variable_id: std::cell::Cell, /// A map from a struct type and method name to a function id for the method. @@ -498,6 +507,8 @@ impl Default for NodeInterner { next_type_variable_id: std::cell::Cell::new(0), globals: Vec::new(), global_attributes: HashMap::new(), + arith_expressions: HashMap::new(), + arith_constraints: Vec::new().into(), struct_methods: HashMap::new(), primitive_methods: HashMap::new(), type_alias_ref: Vec::new(), @@ -706,6 +717,18 @@ impl NodeInterner { id } + pub fn push_arith_expression(&mut self, expr: ArithExpr, location: Location) -> ArithId { + let arith_id = expr.to_id(); + self.arith_expressions.insert(arith_id, (expr, location)); + arith_id + } + + pub fn get_arith_expression(&self, arith_id: ArithId) -> &(ArithExpr, Location) { + self.arith_expressions.get(&arith_id).unwrap_or_else(|| { + panic!("ICE: unknown ArithId ({:?})\n\n{:?}", arith_id, self.arith_expressions) + }) + } + /// Intern an empty function. pub fn push_empty_fn(&mut self) -> FuncId { self.push_fn(HirFunction::empty()) @@ -1161,7 +1184,7 @@ impl NodeInterner { /// failing one. /// If this list of failing constraints is empty, this means type annotations are required. pub fn lookup_trait_implementation( - &self, + &mut self, object_type: &Type, trait_id: TraitId, trait_generics: &[Type], @@ -1207,7 +1230,7 @@ impl NodeInterner { /// Each constraint after the first represents a `where` clause that was followed. /// - 0 trait constraints indicating type annotations are needed to choose an impl. pub fn try_lookup_trait_implementation( - &self, + &mut self, object_type: &Type, trait_id: TraitId, trait_generics: &[Type], @@ -1266,7 +1289,9 @@ impl NodeInterner { let mut check_trait_generics = |impl_generics: &[Type]| { trait_generics.iter().zip(impl_generics).all(|(trait_generic, impl_generic2)| { let impl_generic = impl_generic2.substitute(&instantiation_bindings); - trait_generic.try_unify(&impl_generic, &mut fresh_bindings).is_ok() + trait_generic + .try_unify(&impl_generic, &mut fresh_bindings, &self.arith_constraints) + .is_ok() }) }; @@ -1285,7 +1310,10 @@ impl NodeInterner { continue; } - if object_type.try_unify(&existing_object_type, &mut fresh_bindings).is_ok() { + if object_type + .try_unify(&existing_object_type, &mut fresh_bindings, &self.arith_constraints) + .is_ok() + { if let TraitImplKind::Normal(impl_id) = impl_kind { let trait_impl = self.get_trait_implementation(*impl_id); let trait_impl = trait_impl.borrow(); @@ -1466,20 +1494,20 @@ impl NodeInterner { // Failed to find a match for the type in question, switch to looking at impls // for all types `T`, e.g. `impl Foo for T` let key = &(TypeMethodKey::Generic, method_name.to_owned()); - let global_methods = self.primitive_methods.get(key)?; + let global_methods = self.primitive_methods.get(key)?.clone(); global_methods.find_matching_method(typ, self) } } /// Looks up a given method name on the given primitive type. - pub fn lookup_primitive_method(&self, typ: &Type, method_name: &str) -> Option { + pub fn lookup_primitive_method(&mut self, typ: &Type, method_name: &str) -> Option { let key = get_type_method_key(typ)?; - let methods = self.primitive_methods.get(&(key, method_name.to_owned()))?; - self.find_matching_method(typ, Some(methods), method_name) + let methods = self.primitive_methods.get(&(key, method_name.to_owned()))?.clone(); + self.find_matching_method(typ, Some(&methods), method_name) } pub fn lookup_primitive_trait_method_mut( - &self, + &mut self, typ: &Type, method_name: &str, ) -> Option { @@ -1744,7 +1772,8 @@ impl Methods { if let Some(object) = args.first() { let mut bindings = TypeBindings::new(); - if object.try_unify(typ, &mut bindings).is_ok() { + if object.try_unify(typ, &mut bindings, &interner.arith_constraints).is_ok() + { Type::apply_type_bindings(bindings); return Some(method); } @@ -1800,6 +1829,7 @@ fn get_type_method_key(typ: &Type) -> Option { // We do not support adding methods to these types Type::TypeVariable(_, _) | Type::Forall(_, _) + | Type::GenericArith(..) | Type::Constant(_) | Type::Error | Type::Struct(_, _) diff --git a/compiler/noirc_frontend/src/node_interner/generic_arith.rs b/compiler/noirc_frontend/src/node_interner/generic_arith.rs new file mode 100644 index 00000000000..928af058f58 --- /dev/null +++ b/compiler/noirc_frontend/src/node_interner/generic_arith.rs @@ -0,0 +1,584 @@ +use std::cell::RefCell; +use std::cmp::Ordering; +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; +use std::rc::Rc; + +use noirc_errors::Location; + +use crate::hir_def::types::Type; +use crate::{BinaryTypeOperator, TypeBinding, TypeBindings, TypeVariable, TypeVariableKind}; + +use super::NodeInterner; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum ArithId { + Dummy, + Hash(u64), +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash, Default)] +pub struct ArithGenericId(usize); + +impl ArithGenericId { + fn offset(&self, offset_amount: usize) -> Self { + ArithGenericId(self.0 + offset_amount) + } +} + +/// An arithmetic expression can be a variable, constant, or binary operation. +/// +/// An ArithExpr::Variable contains a NamedGeneric's TypeVariable and name, +/// as well as the ArithGenericId that points to the corresponding TypeVariable +/// in Type::GenericArith +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub enum ArithExpr { + Op { kind: ArithOpKind, lhs: Box, rhs: Box }, + Variable(TypeVariable, Rc, ArithGenericId), + Constant(u64), +} + +impl ArithExpr { + pub fn try_constant(&self) -> Option { + match self { + Self::Constant(x) => Some(*x), + _ => None, + } + } + + pub fn evaluate( + &self, + _interner: &NodeInterner, + arguments: &Vec<(u64, Type)>, + ) -> Result { + match self { + Self::Op { kind, lhs, rhs } => { + // TODO: interner unused, see https://github.com/noir-lang/noir/issues/5150 + let interner = NodeInterner::default(); + let lhs = lhs.evaluate(&interner, arguments)?; + let rhs = rhs.evaluate(&interner, arguments)?; + kind.evaluate(lhs, rhs) + } + Self::Variable(binding, name, index) => { + if let Some((result, _other_var)) = arguments.get(index.0) { + // // TODO: assertion fails https://github.com/noir-lang/noir/issues/5150 + // // (remove other_var if unneeded) + // + // let mut fresh_bindings = TypeBindings::new(); + // assert!(Type::NamedGeneric(binding.clone(), name.clone()) + // .try_unify(other_var, &mut fresh_bindings, &interner.arith_constraints,) + // .is_ok()); + + Ok(*result) + } else { + Err(ArithExprError::UnboundVariable { + binding: binding.clone(), + name: name.to_string(), + }) + } + } + Self::Constant(result) => Ok(*result), + } + } + + /// Apply Type::follow_bindings to each named generic + /// and return the updated version as well as any new generics + fn follow_bindings( + &self, + interner: &NodeInterner, + offset_amount: &mut usize, + ) -> (Self, Vec) { + match self { + Self::Op { kind, lhs, rhs } => { + let (lhs, mut lhs_new_generics) = lhs.follow_bindings(interner, offset_amount); + let (rhs, mut rhs_new_generics) = rhs.follow_bindings(interner, offset_amount); + let rhs = rhs.offset_generic_indices(lhs_new_generics.len()); + lhs_new_generics.append(&mut rhs_new_generics); + (Self::Op { kind: *kind, lhs: Box::new(lhs), rhs: Box::new(rhs) }, lhs_new_generics) + } + Self::Variable(binding, name, index) => { + match Type::NamedGeneric(binding.clone(), name.clone()).follow_bindings() { + Type::GenericArith(arith_id, generics) => { + let (arith_expr, _location) = interner.get_arith_expression(arith_id); + let arith_expr = arith_expr.offset_generic_indices(*offset_amount); + *offset_amount = arith_expr.max_generic_index().0; + (arith_expr, generics) + } + + Type::NamedGeneric(new_binding, new_name) => (Self::Variable(new_binding, new_name, *index), vec![]), + Type::TypeVariable(_new_binding, TypeVariableKind::Constant(value)) => (Self::Constant(value), vec![]), + Type::TypeVariable(new_binding, _kind) => { + Self::Variable(new_binding, name.clone(), *index).follow_bindings(interner, offset_amount) + } + Type::Constant(value) => { + (ArithExpr::Constant(value), vec![]) + } + other => panic!("ICE: follow_bindings on Type::NamedGeneric produced a result other than a variable or constant: {:?}", other), + } + } + Self::Constant(result) => (Self::Constant(*result), vec![]), + } + } + + /// map over Self::Variable's + fn map_variables(&self, f: &mut F) -> Self + where + F: FnMut( + &TypeVariable, + &Rc, + ArithGenericId, + ) -> (TypeVariable, Rc, ArithGenericId), + { + match self { + Self::Op { kind, lhs, rhs } => { + let lhs = Box::new(lhs.map_variables(f)); + let rhs = Box::new(rhs.map_variables(f)); + Self::Op { kind: *kind, lhs, rhs } + } + Self::Variable(binding, name, index) => { + let (new_binding, new_name, new_index) = f(binding, name, *index); + Self::Variable(new_binding, new_name, new_index) + } + Self::Constant(result) => Self::Constant(*result), + } + } + + /// normal form: sort nodes at each branch + fn nf(&self) -> Self { + match self { + Self::Op { kind, lhs, rhs } => { + match kind { + // commutative cases + ArithOpKind::Add | ArithOpKind::Mul => { + let (lhs, rhs) = if lhs <= rhs { + (lhs.clone(), rhs.clone()) + } else { + (rhs.clone(), lhs.clone()) + }; + Self::Op { kind: *kind, lhs, rhs } + } + _ => Self::Op { kind: *kind, lhs: lhs.clone(), rhs: rhs.clone() }, + } + } + other => other.clone(), + } + } + + /// Replace `TypeVariable`s`in Self::Variable with the given Type::TypeVariable's, + /// indexed by `ArithGenericId`` + fn impute_variables(&self, generics: &[Type]) -> Self { + self.map_variables(&mut |_var: &TypeVariable, name: &Rc, index: ArithGenericId| { + let new_var = generics + .get(index.0) + .expect("all variables in a GenericArith ArithExpr to be in the included Vec") + .get_outer_type_variable() + .expect("all args to GenericArith to be NamedGeneric/TypeVariable's"); + (new_var, name.clone(), index) + }) + } + + pub(crate) fn to_id(&self) -> ArithId { + let mut hasher = DefaultHasher::new(); + self.hash(&mut hasher); + ArithId::Hash(hasher.finish()) + } + + pub(crate) fn offset_generic_indices(&self, offset_amount: usize) -> Self { + match self { + Self::Op { kind, lhs, rhs } => { + let lhs = Box::new(lhs.offset_generic_indices(offset_amount)); + let rhs = Box::new(rhs.offset_generic_indices(offset_amount)); + Self::Op { kind: *kind, lhs, rhs } + } + Self::Variable(binding, name, index) => { + Self::Variable(binding.clone(), name.clone(), index.offset(offset_amount)) + } + Self::Constant(result) => Self::Constant(*result), + } + } + + pub(crate) fn max_generic_index(&self) -> ArithGenericId { + match self { + Self::Op { lhs, rhs, .. } => { + let lhs_max = lhs.max_generic_index(); + let rhs_max = rhs.max_generic_index(); + std::cmp::max(lhs_max, rhs_max) + } + Self::Variable(_binding, _name, index) => *index, + Self::Constant(_result) => ArithGenericId::default(), + } + } +} + +impl std::fmt::Display for ArithExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ArithExpr::Op { kind, lhs, rhs } => write!(f, "{lhs} {kind} {rhs}"), + ArithExpr::Variable(binding, name, _index) => match &*binding.borrow() { + TypeBinding::Bound(binding) => binding.fmt(f), + TypeBinding::Unbound(_) if name.is_empty() => write!(f, "_"), + TypeBinding::Unbound(_) => write!(f, "{name}"), + }, + ArithExpr::Constant(x) => x.fmt(f), + } + } +} + +/// Constant < Variable < Op +impl PartialOrd for ArithExpr { + fn partial_cmp(&self, other: &Self) -> Option { + match self { + Self::Op { kind, lhs, rhs } => match other { + Self::Op { kind: other_kind, lhs: other_lhs, rhs: other_rhs } => { + (kind, lhs, rhs).partial_cmp(&(other_kind, other_lhs, other_rhs)) + } + Self::Variable(..) => Some(Ordering::Greater), + Self::Constant(..) => Some(Ordering::Greater), + }, + Self::Variable(binding, name, index) => match other { + Self::Op { .. } => Some(Ordering::Less), + Self::Variable(other_binding, other_name, other_index) => ( + binding.id().0, + name, + index, + ) + .partial_cmp(&(other_binding.id().0, other_name, other_index)), + Self::Constant(..) => Some(Ordering::Greater), + }, + Self::Constant(self_result) => match other { + Self::Op { .. } => Some(Ordering::Less), + Self::Variable(..) => Some(Ordering::Less), + Self::Constant(other_result) => self_result.partial_cmp(other_result), + }, + } + } +} + +/// A binary operation that's allowed in an ArithExpr +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Copy, Clone, Hash)] +pub enum ArithOpKind { + Mul, + Add, + Sub, +} + +impl ArithOpKind { + /// Returns an error on overflow/underflow + pub fn evaluate(&self, x: u64, y: u64) -> Result { + match self { + Self::Mul => Ok(x * y), + Self::Add => Ok(x + y), + Self::Sub => x.checked_sub(y).ok_or(ArithExprError::SubUnderflow { lhs: x, rhs: y }), + } + } + + pub fn from_binary_type_operator(value: BinaryTypeOperator) -> Option { + match value { + BinaryTypeOperator::Addition => Some(ArithOpKind::Add), + BinaryTypeOperator::Multiplication => Some(ArithOpKind::Mul), + BinaryTypeOperator::Subtraction => Some(ArithOpKind::Sub), + _ => None, + } + } +} + +impl std::fmt::Display for ArithOpKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ArithOpKind::Mul => write!(f, "*"), + ArithOpKind::Add => write!(f, "+"), + ArithOpKind::Sub => write!(f, "-"), + } + } +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum ArithExprError { + SubUnderflow { lhs: u64, rhs: u64 }, + + UnboundVariable { binding: TypeVariable, name: String }, + + EvaluateUnexpectedType { unexpected_type: Type }, +} + +impl std::fmt::Display for ArithExprError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::SubUnderflow { lhs, rhs } => { + write!(f, "subtracting {} - {} underflowed", lhs, rhs) + } + Self::UnboundVariable { binding, name } => { + let use_elaborator_notice = + // TODO: https://github.com/noir-lang/noir/issues/5149 + "\nIf you're seeing this error inside of a definition that type checks,\n".to_owned() + + "running with '--use-elaborator' may fix it. (See Issue#5149 for more info)"; + if let TypeBinding::Unbound(_) = &*binding.borrow() { + write!( + f, + "unbound variable when resolving generic arithmetic: {}{}", + name, use_elaborator_notice + ) + } else { + write!( + f, + "unbound variable when resolving generic arithmetic: {}{}", + binding.borrow(), + use_elaborator_notice + ) + } + } + Self::EvaluateUnexpectedType { unexpected_type } => { + write!(f, "unexpected type when evaluating to u64: {}", unexpected_type) + } + } + } +} + +/// Whether either the LHS or RHS of an ArithConstraint needs to be interned, +/// which can happen when unifying +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub enum NeedsInterning { + Lhs(ArithExpr), + Rhs(ArithExpr), + Neither, +} + +/// An arithmetic constraint, composed of the parameters from two Type::GenericArith's and an +/// optional NeedsInterning case. +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub struct ArithConstraint { + pub lhs: ArithId, + pub lhs_generics: Vec, + pub rhs: ArithId, + pub rhs_generics: Vec, + pub needs_interning: NeedsInterning, +} + +impl ArithConstraint { + pub(crate) fn evaluate_generics_to_u64( + generics: &[Type], + location: &Location, + interner: &NodeInterner, + ) -> Result, ArithExprError> { + generics + .iter() + .cloned() + .map(|generic| { + generic.evaluate_to_u64(location, interner).map(|result| (result, generic)) + }) + .collect::, _>>() + } + + pub fn validate(self, interner: &NodeInterner) -> Result<(), ArithConstraintError> { + let (lhs, rhs) = match &self.needs_interning { + NeedsInterning::Lhs(lhs_expr) => ( + (lhs_expr.clone(), Location::dummy()), + interner.get_arith_expression(self.rhs).clone(), + ), + NeedsInterning::Rhs(rhs_expr) => ( + interner.get_arith_expression(self.lhs).clone(), + (rhs_expr.clone(), Location::dummy()), + ), + NeedsInterning::Neither => ( + interner.get_arith_expression(self.lhs).clone(), + interner.get_arith_expression(self.rhs).clone(), + ), + }; + let (lhs_expr, lhs_location) = lhs; + let (rhs_expr, rhs_location) = rhs; + + // follow NamedGeneric bindings + let mut current_generic_index_offset = 0; + let (lhs_expr, lhs_new_generics) = + lhs_expr.follow_bindings(interner, &mut current_generic_index_offset); + let (rhs_expr, rhs_new_generics) = + rhs_expr.follow_bindings(interner, &mut current_generic_index_offset); + rhs_expr.offset_generic_indices(lhs_new_generics.len()); + + let lhs_generics: Vec<_> = self.lhs_generics.into_iter().chain(lhs_new_generics).collect(); + let rhs_generics: Vec<_> = self.rhs_generics.into_iter().chain(rhs_new_generics).collect(); + + match Self::evaluate_generics_to_u64(&lhs_generics, &lhs_location, interner).and_then( + |lhs_generics| { + let rhs_generics = + Self::evaluate_generics_to_u64(&rhs_generics, &rhs_location, interner)?; + Ok((lhs_generics, rhs_generics)) + }, + ) { + // all generics resolved + Ok((lhs_generics, rhs_generics)) => { + match ( + lhs_expr.evaluate(interner, &lhs_generics), + rhs_expr.evaluate(interner, &rhs_generics), + ) { + (Ok(lhs_evaluated), Ok(rhs_evaluated)) => { + if lhs_evaluated == rhs_evaluated { + Ok(()) + } else { + Err(ArithConstraintError::EvaluatedToDifferentValues { + lhs_evaluated, + rhs_evaluated, + location: rhs_location, + other_location: lhs_location, + }) + } + } + (lhs_result, rhs_result) => Err(ArithConstraintError::FailedToEvaluate { + lhs_expr, + rhs_expr, + lhs_result, + rhs_result, + location: lhs_location, + other_location: rhs_location, + }), + } + } + Err(arith_expr_error) => { + let mut fresh_bindings = TypeBindings::new(); + let generics_match = lhs_generics.iter().zip(rhs_generics.iter()).all( + |(lhs_generic, rhs_generic)| { + lhs_generic + .try_unify( + rhs_generic, + &mut fresh_bindings, + &interner.arith_constraints, + ) + .is_ok() + }, + ); + Type::apply_type_bindings(fresh_bindings); + + if generics_match { + // impute the unified lhs_generics into both ArithExpr's + let lhs_expr = lhs_expr.impute_variables(&lhs_generics).nf(); + let rhs_expr = rhs_expr.impute_variables(&lhs_generics).nf(); + if lhs_expr == rhs_expr { + Ok(()) + } else { + Err(ArithConstraintError::DistinctExpressions { + lhs_expr: lhs_expr.clone(), + rhs_expr: rhs_expr.clone(), + generics: lhs_generics.clone(), + location: lhs_location, + other_location: rhs_location, + }) + } + } else { + Err(ArithConstraintError::ArithExprError { + arith_expr_error, + location: lhs_location, + other_locations: vec![rhs_location], + }) + } + } + } + } +} + +pub type ArithConstraints = RefCell>; + +#[derive(Debug, thiserror::Error)] +pub enum ArithConstraintError { + UnresolvedGeneric { + generic: Type, + location: Location, + }, + EvaluatedToDifferentValues { + lhs_evaluated: u64, + rhs_evaluated: u64, + location: Location, + other_location: Location, + }, + FailedToEvaluate { + lhs_expr: ArithExpr, + rhs_expr: ArithExpr, + lhs_result: Result, + rhs_result: Result, + location: Location, + other_location: Location, + }, + DistinctExpressions { + lhs_expr: ArithExpr, + rhs_expr: ArithExpr, + generics: Vec, + location: Location, + other_location: Location, + }, + ArithExprError { + arith_expr_error: ArithExprError, + location: Location, + other_locations: Vec, + }, +} + +impl std::fmt::Display for ArithConstraintError { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::UnresolvedGeneric { generic, .. } => { + if let Type::NamedGeneric(_, name) = generic { + write!(f, "Unresolved generic value: {}", name) + } else { + write!(f, "Unresolved generic value: {}", generic) + } + } + Self::EvaluatedToDifferentValues { lhs_evaluated, rhs_evaluated, .. } => { + write!( + f, + "Generic arithmetic evaluated to different values: {} != {}", + lhs_evaluated, rhs_evaluated + ) + } + Self::FailedToEvaluate { lhs_expr, rhs_expr, lhs_result, rhs_result, .. } => { + if lhs_result.is_err() { + write!( + f, + "Left hand side of generic arithmetic failed to evaluate: {}\n{}\n", + lhs_expr, + lhs_result.as_ref().unwrap_err() + )?; + } + if rhs_result.is_err() { + write!( + f, + "Right hand side of generic arithmetic failed to evaluate: {}\n{}", + rhs_expr, + rhs_result.as_ref().unwrap_err() + )?; + } + assert!( + lhs_result.is_err() || rhs_result.is_err(), + "ArithConstraintError::FailedToEvaluate contains successful evaluation" + ); + Ok(()) + } + Self::DistinctExpressions { lhs_expr, rhs_expr, generics, .. } => { + write!(f, "Generic arithmetic appears to be distinct: {} != {}, where the arguments are: {:?}", lhs_expr, rhs_expr, generics) + } + Self::ArithExprError { arith_expr_error, .. } => arith_expr_error.fmt(f), + } + } +} + +impl ArithConstraintError { + pub fn location(&self) -> Location { + match self { + Self::UnresolvedGeneric { location, .. } + | Self::EvaluatedToDifferentValues { location, .. } + | Self::FailedToEvaluate { location, .. } + | Self::DistinctExpressions { location, .. } + | Self::ArithExprError { location, .. } => *location, + } + } + + pub fn other_locations(&self) -> Vec { + match self { + Self::UnresolvedGeneric { .. } => vec![], + + Self::EvaluatedToDifferentValues { other_location, .. } + | Self::FailedToEvaluate { other_location, .. } + | Self::DistinctExpressions { other_location, .. } => vec![*other_location], + + Self::ArithExprError { other_locations, .. } => other_locations.clone(), + } + } +} diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 99215c8f173..c72fc3ecfbc 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -1380,6 +1380,21 @@ fn deny_fold_attribute_on_unconstrained() { )); } +#[test] +fn basic_generic_arith() { + let src = r#" + fn foo(_x: [Field; N]) -> [Field; N+1] { + [0; N+1] + } + "#; + let errors = get_program_errors(src); + + if !errors.is_empty() { + dbg!(&errors); + } + assert_eq!(errors.len(), 0); +} + #[test] fn specify_function_types_with_turbofish() { let src = r#" diff --git a/test_programs/compile_failure/impl_on_generic_arith/Nargo.toml b/test_programs/compile_failure/impl_on_generic_arith/Nargo.toml new file mode 100644 index 00000000000..4450409f598 --- /dev/null +++ b/test_programs/compile_failure/impl_on_generic_arith/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "impl_on_generic_arith" +type = "bin" +authors = [""] +compiler_version = ">=0.29.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/compile_failure/impl_on_generic_arith/src/main.nr b/test_programs/compile_failure/impl_on_generic_arith/src/main.nr new file mode 100644 index 00000000000..c8c8db6d17f --- /dev/null +++ b/test_programs/compile_failure/impl_on_generic_arith/src/main.nr @@ -0,0 +1,11 @@ +trait MyTrait { } + +struct WrappedArray { + inner: [T; N] +} + +// expected to fail with: +// error: impl's are not allowed on generic arithmetic +impl MyTrait for WrappedArray { } + +fn main() { } diff --git a/test_programs/compile_failure/impl_with_generic_arith_args/Nargo.toml b/test_programs/compile_failure/impl_with_generic_arith_args/Nargo.toml new file mode 100644 index 00000000000..33cc7d56ad7 --- /dev/null +++ b/test_programs/compile_failure/impl_with_generic_arith_args/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "impl_with_generic_arith_args" +type = "bin" +authors = [""] +compiler_version = ">=0.29.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/compile_failure/impl_with_generic_arith_args/src/main.nr b/test_programs/compile_failure/impl_with_generic_arith_args/src/main.nr new file mode 100644 index 00000000000..a65ba910013 --- /dev/null +++ b/test_programs/compile_failure/impl_with_generic_arith_args/src/main.nr @@ -0,0 +1,9 @@ +struct WrappedArray { + inner: [T; N] +} + +trait MyTrait { } + +impl MyTrait for WrappedArray { } + +fn main() { } diff --git a/test_programs/compile_failure/unify_different_generic_arith/Nargo.toml b/test_programs/compile_failure/unify_different_generic_arith/Nargo.toml new file mode 100644 index 00000000000..c9482f18e22 --- /dev/null +++ b/test_programs/compile_failure/unify_different_generic_arith/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "unify_different_generic_arith" +type = "bin" +authors = [""] +compiler_version = ">=0.29.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/compile_failure/unify_different_generic_arith/src/main.nr b/test_programs/compile_failure/unify_different_generic_arith/src/main.nr new file mode 100644 index 00000000000..34933271d5a --- /dev/null +++ b/test_programs/compile_failure/unify_different_generic_arith/src/main.nr @@ -0,0 +1,7 @@ +fn foo(_x: [Field; N]) -> [Field; N+1] { + [0; N+2] +} + +fn main() { + assert(0 == 0); +} diff --git a/test_programs/execution_success/generic_arith/Nargo.toml b/test_programs/execution_success/generic_arith/Nargo.toml new file mode 100644 index 00000000000..c9d6944b72f --- /dev/null +++ b/test_programs/execution_success/generic_arith/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "generic_arith" +type = "bin" +authors = [""] +compiler_version = ">=0.29.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/execution_success/generic_arith/Prover.toml b/test_programs/execution_success/generic_arith/Prover.toml new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test_programs/execution_success/generic_arith/src/main.nr b/test_programs/execution_success/generic_arith/src/main.nr new file mode 100644 index 00000000000..a01acc0bc7d --- /dev/null +++ b/test_programs/execution_success/generic_arith/src/main.nr @@ -0,0 +1,176 @@ +// BEGIN basic function tests +fn foo(_x: [Field; N]) -> [Field; N + 1] { + [0; N + 1] +} + +fn foo_2(x: [Field; N]) -> [Field; N + 1] { + foo(x) +} + +fn bar(_x: [Field; N]) -> [Field; N - 1] { + [0; N - 1] +} + +fn tail(xs: [T; N]) -> [T; N - 1] { + let mut out: [T; N - 1] = [dep::std::unsafe::zeroed(); N - 1]; + for i in 1..N { + out[i-1] = xs[i]; + } + out +} + +fn backwards_tail(xs: [T; N + 1]) -> [T; N] { + let mut out: [T; N] = [dep::std::unsafe::zeroed(); N]; + for i in 1..(N+1) { + out[i-1] = xs[i]; + } + out +} + +fn cons(x: T, xs: [T; N]) -> [T; N + 1] { + let mut out: [T; N + 1] = [dep::std::unsafe::zeroed(); N + 1]; + out[0] = x; + for i in 1..(N+1) { + out[i] = xs[i-1]; + } + out +} + +fn split_first(x: [T; N]) -> (T, [T; N - 1]) { + (x[0], tail(x)) +} +// END basic function tests + +// BEGIN trait tests +struct WrappedArray { + inner: [T; N] +} + +impl WrappedArray { + fn split_first(self) -> (T, WrappedArray) { + let result = split_first(self.inner); + (result.0, WrappedArray { inner: result.1 }) + } +} + +impl Eq for WrappedArray where T: Eq { + fn eq(self, other: WrappedArray) -> bool { + self.inner.eq(other.inner) + } +} + +trait MyTrait { + fn my_trait_method(self) -> [Field; 3] { + [0; 3] + } +} + +impl MyTrait<0> for () { } +// END trait tests + +// BEGIN serialize test +trait Serialize { + fn serialize(self) -> [Field; N]; +} + +impl Serialize<0> for () { + fn serialize(self) -> [Field; 0] { + [] + } +} + +impl Serialize<1> for Field { + fn serialize(self) -> [Field; 1] { + [self] + } +} + +struct Foo { + a: T, + b: T, +} + +// error: impl's are not allowed to have generic arithmetic arguments +// +// impl Serialize for Foo { +// fn serialize_foo(x: Foo) -> [Field; N * 2] where T: Serialize { +// [0; N * 2] +// } +// } +fn serialize_foo(x: Foo) -> [Field; N * 2] where T: Serialize { + let out_a: [Field; N] = x.a.serialize(); + let out_b: [Field; N] = x.b.serialize(); + let mut result: [Field; N * 2] = [0; N * 2]; + for i in 0..out_a.len() { + result[i] = out_a[i]; + result[out_a.len() + i] = out_b[i]; + } + result +} +// END serialize test + +fn main() { + let input_array: [Field; 2] = [1, 2]; + let foo_result: [Field; 3] = foo(input_array); + assert(foo_result == [0, 0, 0]); + + + // tail + let input_array: [Field; 3] = [1, 2, 3]; + let result: [Field; 2] = tail(input_array); + assert(result == [2, 3]); + + let result: [Field; 1] = tail(result); + assert(result == [3]); + + + // backwards_tail + let input_array: [Field; 3] = [1, 2, 3]; + let result: [Field; 2] = backwards_tail(input_array); + assert(result == [2, 3]); + + let result: [Field; 1] = backwards_tail(result); + assert(result == [3]); + + + // cons + let input_array: [Field; 0] = []; + let result: [Field; 1] = cons(2, input_array); + assert(result == [2]); + + // NOTE: fixed in the elaborator + // + // unexpected error when running these in one line: + // + // error: Failed to prove generic arithmetic equivalent: + // Generic arithmetic evaluated differently: Ok(2) != Err(UnboundVariable { .. }) + // test_programs/compile_success_empty/generic_arith/src/main.nr:15:22 + // assert(tail([1,2,3]) == [2, 3]); + + // it also fails when using turbofish + // + // assert(tail::([1,2,3]) == [2, 3]); + + // it also fails when compared to another function output + // + // assert(tail([1,2,3]) == tail([2, 2, 3])); + + let wrapped_array = WrappedArray { inner: [1, 2, 3] }; + let result: (Field, WrappedArray) = wrapped_array.split_first(); + let expected_result: (Field, WrappedArray) = (1, WrappedArray { inner: [2, 3] }); + assert(result == expected_result); + + // this fails with what appears to be the same error, + // viz. that the location is in the _function_ that has already type-checked: + // + // error: Failed to prove generic arithmetic equivalent: + // Generic arithmetic evaluated differently: Ok(2) != Err(UnboundVariable { .. }) + // test_programs/compile_success_empty/generic_arith/src/main.nr:31:49 + // + // assert(wrapped_array.split_first() == (1, WrappedArray { inner: [2, 3] })); + + // serialize test + let test_foo = Foo { a: 2, b: 3 }; + let result: [Field; 2] = serialize_foo(test_foo); + assert(result == [2, 3]); +} diff --git a/tooling/debugger/ignored-tests.txt b/tooling/debugger/ignored-tests.txt index a9193896589..2d2c36cabbb 100644 --- a/tooling/debugger/ignored-tests.txt +++ b/tooling/debugger/ignored-tests.txt @@ -10,8 +10,9 @@ fold_complex_outputs fold_distinct_return fold_fibonacci fold_numeric_generic_poseidon +generic_arith is_unconstrained modulus references regression_4709 -to_bytes_integration \ No newline at end of file +to_bytes_integration diff --git a/tooling/noirc_abi/Cargo.toml b/tooling/noirc_abi/Cargo.toml index baae2dfa35e..20052400278 100644 --- a/tooling/noirc_abi/Cargo.toml +++ b/tooling/noirc_abi/Cargo.toml @@ -11,6 +11,7 @@ license.workspace = true [dependencies] acvm.workspace = true iter-extended.workspace = true +noirc_errors.workspace = true noirc_frontend.workspace = true noirc_printable_type.workspace = true toml.workspace = true diff --git a/tooling/noirc_abi/src/lib.rs b/tooling/noirc_abi/src/lib.rs index 0acace71fb3..a029a1c03e8 100644 --- a/tooling/noirc_abi/src/lib.rs +++ b/tooling/noirc_abi/src/lib.rs @@ -13,7 +13,9 @@ use acvm::{ use errors::AbiError; use input_parser::InputValue; use iter_extended::{try_btree_map, try_vecmap, vecmap}; +use noirc_errors::Location; use noirc_frontend::ast::{Signedness, Visibility}; +use noirc_frontend::node_interner::NodeInterner; use noirc_frontend::{hir::Context, Type, TypeBinding, TypeVariableKind}; use noirc_printable_type::{ decode_value as printable_type_decode_value, PrintableType, PrintableValue, @@ -131,13 +133,14 @@ pub enum Sign { impl AbiType { pub fn from_type(context: &Context, typ: &Type) -> Self { + let dummy_interner = NodeInterner::default(); // Note; use strict_eq instead of partial_eq when comparing field types // in this method, you most likely want to distinguish between public and private match typ { Type::FieldElement => Self::Field, Type::Array(size, typ) => { let length = size - .evaluate_to_u64() + .evaluate_to_u64(&Location::dummy(), &dummy_interner) .expect("Cannot have variable sized arrays as a parameter to main"); let typ = typ.as_ref(); Self::Array { length, typ: Box::new(Self::from_type(context, typ)) } @@ -160,7 +163,7 @@ impl AbiType { Type::Bool => Self::Boolean, Type::String(size) => { let size = size - .evaluate_to_u64() + .evaluate_to_u64(&Location::dummy(), &dummy_interner) .expect("Cannot have variable sized strings as a parameter to main"); Self::String { length: size } } @@ -186,6 +189,7 @@ impl AbiType { | Type::TypeVariable(_, _) | Type::NamedGeneric(..) | Type::Forall(..) + | Type::GenericArith(..) | Type::Code | Type::Slice(_) | Type::Function(_, _, _) => unreachable!("{typ} cannot be used in the abi"), @@ -607,7 +611,10 @@ impl AbiErrorType { pub fn from_type(context: &Context, typ: &Type) -> Self { match typ { Type::FmtString(len, item_types) => { - let length = len.evaluate_to_u64().expect("Cannot evaluate fmt length"); + let dummy_interner = NodeInterner::default(); + let length = len + .evaluate_to_u64(&Location::dummy(), &dummy_interner) + .expect("Cannot evaluate fmt length"); let Type::Tuple(item_types) = item_types.as_ref() else { unreachable!("FmtString items must be a tuple") };