diff --git a/compiler/noirc_frontend/src/ast/expression.rs b/compiler/noirc_frontend/src/ast/expression.rs index 15b781051bc..01705e0af88 100644 --- a/compiler/noirc_frontend/src/ast/expression.rs +++ b/compiler/noirc_frontend/src/ast/expression.rs @@ -245,6 +245,45 @@ impl Expression { pub fn new(kind: ExpressionKind, span: Span) -> Expression { Expression { kind, span } } + + /// Returns the innermost span that gives this expression its type. + pub fn type_span(&self) -> Span { + match &self.kind { + ExpressionKind::Block(block_expression) + | ExpressionKind::Comptime(block_expression, _) + | ExpressionKind::Unsafe(block_expression, _) => { + if let Some(statement) = block_expression.statements.last() { + statement.type_span() + } else { + self.span + } + } + ExpressionKind::Parenthesized(expression) => expression.type_span(), + ExpressionKind::Literal(..) + | ExpressionKind::Prefix(..) + | ExpressionKind::Index(..) + | ExpressionKind::Call(..) + | ExpressionKind::MethodCall(..) + | ExpressionKind::Constrain(..) + | ExpressionKind::Constructor(..) + | ExpressionKind::MemberAccess(..) + | ExpressionKind::Cast(..) + | ExpressionKind::Infix(..) + | ExpressionKind::If(..) + | ExpressionKind::Match(..) + | ExpressionKind::Variable(..) + | ExpressionKind::Tuple(..) + | ExpressionKind::Lambda(..) + | ExpressionKind::Quote(..) + | ExpressionKind::Unquote(..) + | ExpressionKind::AsTraitPath(..) + | ExpressionKind::TypePath(..) + | ExpressionKind::Resolved(..) + | ExpressionKind::Interned(..) + | ExpressionKind::InternedStatement(..) + | ExpressionKind::Error => self.span, + } + } } pub type BinaryOp = Spanned; diff --git a/compiler/noirc_frontend/src/ast/statement.rs b/compiler/noirc_frontend/src/ast/statement.rs index 145cff0a341..372f20f8780 100644 --- a/compiler/noirc_frontend/src/ast/statement.rs +++ b/compiler/noirc_frontend/src/ast/statement.rs @@ -73,6 +73,24 @@ impl Statement { self.kind = self.kind.add_semicolon(semi, span, last_statement_in_block, emit_error); self } + + /// Returns the innermost span that gives this statement its type. + pub fn type_span(&self) -> Span { + match &self.kind { + StatementKind::Expression(expression) => expression.type_span(), + StatementKind::Comptime(statement) => statement.type_span(), + StatementKind::Let(..) + | StatementKind::Assign(..) + | StatementKind::For(..) + | StatementKind::Loop(..) + | StatementKind::While(..) + | StatementKind::Break + | StatementKind::Continue + | StatementKind::Semi(..) + | StatementKind::Interned(..) + | StatementKind::Error => self.span, + } + } } impl StatementKind { diff --git a/compiler/noirc_frontend/src/elaborator/enums.rs b/compiler/noirc_frontend/src/elaborator/enums.rs index ffc2f022120..02d9bfae494 100644 --- a/compiler/noirc_frontend/src/elaborator/enums.rs +++ b/compiler/noirc_frontend/src/elaborator/enums.rs @@ -274,7 +274,7 @@ impl Elaborator<'_> { let columns = vec![Column::new(variable_to_match, pattern)]; let guard = None; - let body_span = branch.span; + let body_span = branch.type_span(); let (body, body_type) = self.elaborate_expression(branch); self.unify(&body_type, &result_type, || TypeCheckError::TypeMismatch { @@ -291,7 +291,7 @@ impl Elaborator<'_> { /// Convert an expression into a Pattern, defining any variables within. fn expression_to_pattern(&mut self, expression: Expression, expected_type: &Type) -> Pattern { - let expr_span = expression.span; + let expr_span = expression.type_span(); let unify_with_expected_type = |this: &mut Self, actual| { this.unify(actual, expected_type, || TypeCheckError::TypeMismatch { expected_typ: expected_type.to_string(), diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index 18d5e3be82e..3b25f85a25c 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -971,8 +971,8 @@ impl<'context> Elaborator<'context> { if_expr: IfExpression, target_type: Option<&Type>, ) -> (HirExpression, Type) { - let expr_span = if_expr.condition.span; - let consequence_span = if_expr.consequence.span; + let expr_span = if_expr.condition.type_span(); + let consequence_span = if_expr.consequence.type_span(); let (condition, cond_type) = self.elaborate_expression(if_expr.condition); let (consequence, mut ret_type) = self.elaborate_expression_with_target_type(if_expr.consequence, target_type); @@ -984,9 +984,10 @@ impl<'context> Elaborator<'context> { }); let (alternative, else_type, error_span) = if let Some(alternative) = if_expr.alternative { + let alternative_span = alternative.type_span(); let (else_, else_type) = self.elaborate_expression_with_target_type(alternative, target_type); - (Some(else_), else_type, expr_span) + (Some(else_), else_type, alternative_span) } else { (None, Type::Unit, consequence_span) }; diff --git a/compiler/noirc_frontend/src/elaborator/statements.rs b/compiler/noirc_frontend/src/elaborator/statements.rs index 8b60a660a16..3379db4aa66 100644 --- a/compiler/noirc_frontend/src/elaborator/statements.rs +++ b/compiler/noirc_frontend/src/elaborator/statements.rs @@ -219,7 +219,14 @@ impl<'context> Elaborator<'context> { self.interner.push_definition_type(identifier.id, start_range_type); - let (block, _block_type) = self.elaborate_expression(block); + let block_span = block.type_span(); + let (block, block_type) = self.elaborate_expression(block); + + self.unify(&block_type, &Type::Unit, || TypeCheckError::TypeMismatch { + expected_typ: Type::Unit.to_string(), + expr_typ: block_type.to_string(), + expr_span: block_span, + }); self.pop_scope(); self.current_loop = old_loop; @@ -244,7 +251,14 @@ impl<'context> Elaborator<'context> { self.current_loop = Some(Loop { is_for: false, has_break: false }); self.push_scope(); - let (block, _block_type) = self.elaborate_expression(block); + let block_span = block.type_span(); + let (block, block_type) = self.elaborate_expression(block); + + self.unify(&block_type, &Type::Unit, || TypeCheckError::TypeMismatch { + expected_typ: Type::Unit.to_string(), + expr_typ: block_type.to_string(), + expr_span: block_span, + }); self.pop_scope(); @@ -269,9 +283,8 @@ impl<'context> Elaborator<'context> { self.current_loop = Some(Loop { is_for: false, has_break: false }); self.push_scope(); - let condition_span = while_.condition.span; + let condition_span = while_.condition.type_span(); let (condition, cond_type) = self.elaborate_expression(while_.condition); - let (block, _block_type) = self.elaborate_expression(while_.body); self.unify(&cond_type, &Type::Bool, || TypeCheckError::TypeMismatch { expected_typ: Type::Bool.to_string(), @@ -279,6 +292,15 @@ impl<'context> Elaborator<'context> { expr_span: condition_span, }); + let block_span = while_.body.type_span(); + let (block, block_type) = self.elaborate_expression(while_.body); + + self.unify(&block_type, &Type::Unit, || TypeCheckError::TypeMismatch { + expected_typ: Type::Unit.to_string(), + expr_typ: block_type.to_string(), + expr_span: block_span, + }); + self.pop_scope(); std::mem::replace(&mut self.current_loop, old_loop).expect("Expected a loop"); diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 1b3a19a5cfc..e9af0f64bc0 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -4371,3 +4371,56 @@ fn does_not_stack_overflow_on_many_comments_in_a_row() { let src = "//\n".repeat(10_000); assert_no_errors(&src); } + +#[test] +fn errors_if_for_body_type_is_not_unit() { + let src = r#" + fn main() { + for _ in 0..1 { + 1 + } + } + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + let CompilationError::TypeError(TypeCheckError::TypeMismatch { .. }) = &errors[0].0 else { + panic!("Expected a TypeMismatch error"); + }; +} + +#[test] +fn errors_if_loop_body_type_is_not_unit() { + let src = r#" + unconstrained fn main() { + loop { + if false { break; } + + 1 + } + } + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + let CompilationError::TypeError(TypeCheckError::TypeMismatch { .. }) = &errors[0].0 else { + panic!("Expected a TypeMismatch error"); + }; +} + +#[test] +fn errors_if_while_body_type_is_not_unit() { + let src = r#" + unconstrained fn main() { + while 1 == 1 { + 1 + } + } + "#; + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + let CompilationError::TypeError(TypeCheckError::TypeMismatch { .. }) = &errors[0].0 else { + panic!("Expected a TypeMismatch error"); + }; +}