diff --git a/compiler/noirc_frontend/src/elaborator/statements.rs b/compiler/noirc_frontend/src/elaborator/statements.rs index f1af1371cad..e853615b3db 100644 --- a/compiler/noirc_frontend/src/elaborator/statements.rs +++ b/compiler/noirc_frontend/src/elaborator/statements.rs @@ -601,77 +601,6 @@ impl Elaborator<'_> { Some((let_, ident_id)) } - /// Type checks a field access, adding dereference operators as necessary - pub(super) fn check_field_access( - &mut self, - lhs_type: &Type, - field_name: &str, - location: Location, - dereference_lhs: Option, - ) -> Option<(Type, usize)> { - let lhs_type = lhs_type.follow_bindings(); - - match &lhs_type { - Type::DataType(s, args) => { - let s = s.borrow(); - if let Some((field, visibility, index)) = s.get_field(field_name, args) { - self.interner.add_struct_member_reference(s.id, index, location); - - self.check_struct_field_visibility(&s, field_name, visibility, location); - - return Some((field, index)); - } - } - Type::Tuple(elements) => { - if let Ok(index) = field_name.parse::() { - let length = elements.len(); - if index < length { - return Some((elements[index].clone(), index)); - } else { - self.push_err(TypeCheckError::TupleIndexOutOfBounds { - index, - lhs_type, - length, - location, - }); - return None; - } - } - } - // If the lhs is a reference we automatically transform `lhs.field` into `(*lhs).field` - Type::Reference(element, mutable) => { - if let Some(mut dereference_lhs) = dereference_lhs { - dereference_lhs(self, lhs_type.clone(), element.as_ref().clone()); - return self.check_field_access( - element, - field_name, - location, - Some(dereference_lhs), - ); - } else { - let (element, index) = - self.check_field_access(element, field_name, location, dereference_lhs)?; - return Some((Type::Reference(Box::new(element), *mutable), index)); - } - } - _ => (), - } - - // If we get here the type has no field named 'access.rhs'. - // Now we specialize the error message based on whether we know the object type in question yet. - if let Type::TypeVariable(..) = &lhs_type { - self.push_err(TypeCheckError::TypeAnnotationsNeededForFieldAccess { location }); - } else if lhs_type != Type::Error { - self.push_err(TypeCheckError::AccessUnknownMember { - lhs_type, - field_name: field_name.to_string(), - location, - }); - } - - None - } - fn elaborate_comptime_statement(&mut self, statement: Statement) -> (HirStatement, Type) { let location = statement.location; let (hir_statement, _typ) = diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index fd0b5cca45a..a7d10c6a049 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -1699,6 +1699,77 @@ impl Elaborator<'_> { } } + /// Type checks a field access, adding dereference operators as necessary + pub(super) fn check_field_access( + &mut self, + lhs_type: &Type, + field_name: &str, + location: Location, + dereference_lhs: Option, + ) -> Option<(Type, usize)> { + let lhs_type = lhs_type.follow_bindings(); + + match &lhs_type { + Type::DataType(s, args) => { + let s = s.borrow(); + if let Some((field, visibility, index)) = s.get_field(field_name, args) { + self.interner.add_struct_member_reference(s.id, index, location); + + self.check_struct_field_visibility(&s, field_name, visibility, location); + + return Some((field, index)); + } + } + Type::Tuple(elements) => { + if let Ok(index) = field_name.parse::() { + let length = elements.len(); + if index < length { + return Some((elements[index].clone(), index)); + } else { + self.push_err(TypeCheckError::TupleIndexOutOfBounds { + index, + lhs_type, + length, + location, + }); + return None; + } + } + } + // If the lhs is a reference we automatically transform `lhs.field` into `(*lhs).field` + Type::Reference(element, mutable) => { + if let Some(mut dereference_lhs) = dereference_lhs { + dereference_lhs(self, lhs_type.clone(), element.as_ref().clone()); + return self.check_field_access( + element, + field_name, + location, + Some(dereference_lhs), + ); + } else { + let (element, index) = + self.check_field_access(element, field_name, location, dereference_lhs)?; + return Some((Type::Reference(Box::new(element), *mutable), index)); + } + } + _ => (), + } + + // If we get here the type has no field named 'access.rhs'. + // Now we specialize the error message based on whether we know the object type in question yet. + if let Type::TypeVariable(..) = &lhs_type { + self.push_err(TypeCheckError::TypeAnnotationsNeededForFieldAccess { location }); + } else if lhs_type != Type::Error { + self.push_err(TypeCheckError::AccessUnknownMember { + lhs_type, + field_name: field_name.to_string(), + location, + }); + } + + None + } + /// Try to look up a method on a [Type] by name: /// * if the object type is generic, look it up in the trait constraints /// * otherwise look it up directly on the type, or in traits the type implements diff --git a/compiler/noirc_frontend/src/tests/assignment.rs b/compiler/noirc_frontend/src/tests/assignment.rs index 9b25001a2c3..2be55e4bd16 100644 --- a/compiler/noirc_frontend/src/tests/assignment.rs +++ b/compiler/noirc_frontend/src/tests/assignment.rs @@ -437,6 +437,33 @@ fn dereference_in_lvalue() { assert_no_errors(src); } +#[test] +fn reference_chain_in_tuple_member_access() { + // Ensure that references appearing in the middle of a member access chain are properly dereferenced. + // + // 1. `x` has type `(&mut (u32, &mut (u32, u32)), u32)` + // 2. Accessing `.0` yields `&mut (u32, &mut (u32, u32))` + // 3. Must dereference to get `(u32, &mut (u32, u32))` + // 4. Accessing `.1` yields `&mut (u32, u32)` + // 5. Must dereference to get `(u32, u32)` + // 6. Accessing `.0` yields `u32` + let src = r#" + fn main() { + let inner = &mut (10, 20); + let outer = &mut (5, inner); + let mut x = (outer, 99); + + x.0.1.0 = 42; + + assert(x.0.1.0 == 42); + assert(x.0.1.1 == 20); + assert(x.0.0 == 5); + assert(x.1 == 99); + } + "#; + assert_no_errors(src); +} + #[test] fn mut_comptime_variable_in_runtime() { let src = r#"