diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs index 3a73c3eb09a..995ab5bd86e 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter.rs @@ -241,6 +241,8 @@ impl<'a> Interpreter<'a> { Ok(()) } HirPattern::Mutable(pattern, _) => { + // Create a mutable reference to store to + let argument = Value::Pointer(Shared::new(argument), true); self.define_pattern(pattern, typ, argument, location) } HirPattern::Tuple(pattern_fields, _) => match (argument, typ) { @@ -334,8 +336,19 @@ impl<'a> Interpreter<'a> { } } - /// Evaluate an expression and return the result + /// Evaluate an expression and return the result. + /// This will automatically dereference a mutable variable if used. pub fn evaluate(&mut self, id: ExprId) -> IResult { + match self.evaluate_no_dereference(id)? { + Value::Pointer(elem, true) => Ok(elem.borrow().clone()), + other => Ok(other), + } + } + + /// Evaluating a mutable variable will dereference it automatically. + /// This function should be used when that is not desired - e.g. when + /// compiling a `&mut var` expression to grab the original reference. + fn evaluate_no_dereference(&mut self, id: ExprId) -> IResult { match self.interner.expression(&id) { HirExpression::Ident(ident, _) => self.evaluate_ident(ident, id), HirExpression::Literal(literal) => self.evaluate_literal(literal, id), @@ -592,7 +605,10 @@ impl<'a> Interpreter<'a> { } fn evaluate_prefix(&mut self, prefix: HirPrefixExpression, id: ExprId) -> IResult { - let rhs = self.evaluate(prefix.rhs)?; + let rhs = match prefix.operator { + UnaryOp::MutableReference => self.evaluate_no_dereference(prefix.rhs)?, + _ => self.evaluate(prefix.rhs)?, + }; self.evaluate_prefix_with_value(rhs, prefix.operator, id) } @@ -634,9 +650,17 @@ impl<'a> Interpreter<'a> { Err(InterpreterError::InvalidValueForUnary { value, location, operator: "not" }) } }, - UnaryOp::MutableReference => Ok(Value::Pointer(Shared::new(rhs))), + UnaryOp::MutableReference => { + // If this is a mutable variable (auto_deref = true), turn this into an explicit + // mutable reference just by switching the value of `auto_deref`. Otherwise, wrap + // the value in a fresh reference. + match rhs { + Value::Pointer(elem, true) => Ok(Value::Pointer(elem, false)), + other => Ok(Value::Pointer(Shared::new(other), false)), + } + } UnaryOp::Dereference { implicitly_added: _ } => match rhs { - Value::Pointer(element) => Ok(element.borrow().clone()), + Value::Pointer(element, _) => Ok(element.borrow().clone()), value => { let location = self.interner.expr_location(&id); Err(InterpreterError::NonPointerDereferenced { value, location }) @@ -1300,12 +1324,10 @@ impl<'a> Interpreter<'a> { fn store_lvalue(&mut self, lvalue: HirLValue, rhs: Value) -> IResult<()> { match lvalue { - HirLValue::Ident(ident, typ) => { - self.mutate(ident.id, rhs, ident.location) - } + HirLValue::Ident(ident, typ) => self.mutate(ident.id, rhs, ident.location), HirLValue::Dereference { lvalue, element_type: _, location } => { match self.evaluate_lvalue(&lvalue)? { - Value::Pointer(value) => { + Value::Pointer(value, _) => { *value.borrow_mut() = rhs; Ok(()) } @@ -1355,10 +1377,13 @@ impl<'a> Interpreter<'a> { fn evaluate_lvalue(&mut self, lvalue: &HirLValue) -> IResult { match lvalue { - HirLValue::Ident(ident, _) => self.lookup(ident), + HirLValue::Ident(ident, _) => match self.lookup(ident)? { + Value::Pointer(elem, true) => Ok(elem.borrow().clone()), + other => Ok(other), + }, HirLValue::Dereference { lvalue, element_type: _, location } => { match self.evaluate_lvalue(lvalue)? { - Value::Pointer(value) => Ok(value.borrow().clone()), + Value::Pointer(value, _) => Ok(value.borrow().clone()), value => { Err(InterpreterError::NonPointerDereferenced { value, location: *location }) } diff --git a/compiler/noirc_frontend/src/hir/comptime/tests.rs b/compiler/noirc_frontend/src/hir/comptime/tests.rs index e8e05506c94..6fdd956caf6 100644 --- a/compiler/noirc_frontend/src/hir/comptime/tests.rs +++ b/compiler/noirc_frontend/src/hir/comptime/tests.rs @@ -77,6 +77,18 @@ fn mutating_mutable_references() { assert_eq!(result, Value::I64(4)); } +#[test] +fn mutation_leaks() { + let program = "comptime fn main() -> pub i8 { + let mut x = 3; + let y = &mut x; + *y = 5; + x + }"; + let result = interpret(program, vec!["main".into()]); + assert_eq!(result, Value::I8(5)); +} + #[test] fn mutating_arrays() { let program = "comptime fn main() -> pub u8 { diff --git a/compiler/noirc_frontend/src/hir/comptime/value.rs b/compiler/noirc_frontend/src/hir/comptime/value.rs index 0d42f7aae54..9eeb323d664 100644 --- a/compiler/noirc_frontend/src/hir/comptime/value.rs +++ b/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -41,7 +41,7 @@ pub enum Value { Closure(HirLambda, Vec, Type), Tuple(Vec), Struct(HashMap, Value>, Type), - Pointer(Shared), + Pointer(Shared, /* auto_deref */ bool), Array(Vector, Type), Slice(Vector, Type), Code(Rc), @@ -80,7 +80,7 @@ impl Value { Value::Slice(_, typ) => return Cow::Borrowed(typ), Value::Code(_) => Type::Quoted(QuotedType::Quoted), Value::StructDefinition(_) => Type::Quoted(QuotedType::StructDefinition), - Value::Pointer(element) => { + Value::Pointer(element, _) => { let element = element.borrow().get_type().into_owned(); Type::MutableReference(Box::new(element)) } @@ -201,7 +201,7 @@ impl Value { } }; } - Value::Pointer(_) + Value::Pointer(..) | Value::StructDefinition(_) | Value::TraitConstraint(_) | Value::TraitDefinition(_) @@ -312,7 +312,7 @@ impl Value { HirExpression::Literal(HirLiteral::Slice(HirArrayLiteral::Standard(elements))) } Value::Code(block) => HirExpression::Unquote(unwrap_rc(block)), - Value::Pointer(_) + Value::Pointer(..) | Value::StructDefinition(_) | Value::TraitConstraint(_) | Value::TraitDefinition(_) @@ -404,7 +404,7 @@ impl Display for Value { let fields = vecmap(fields, |(name, value)| format!("{}: {}", name, value)); write!(f, "{typename} {{ {} }}", fields.join(", ")) } - Value::Pointer(value) => write!(f, "&mut {}", value.borrow()), + Value::Pointer(value, _) => write!(f, "&mut {}", value.borrow()), Value::Array(values, _) => { let values = vecmap(values, ToString::to_string); write!(f, "[{}]", values.join(", ")) diff --git a/compiler/noirc_frontend/src/parser/parser/structs.rs b/compiler/noirc_frontend/src/parser/parser/structs.rs index 3055e9904df..9a3adf74d7f 100644 --- a/compiler/noirc_frontend/src/parser/parser/structs.rs +++ b/compiler/noirc_frontend/src/parser/parser/structs.rs @@ -36,7 +36,14 @@ pub(super) fn struct_definition() -> impl NoirParser { .then(fields) .validate(|((((attributes, is_comptime), name), generics), fields), span, emit| { let attributes = validate_secondary_attributes(attributes, span, emit); - TopLevelStatement::Struct(NoirStruct { name, attributes, generics, fields, span, is_comptime }) + TopLevelStatement::Struct(NoirStruct { + name, + attributes, + generics, + fields, + span, + is_comptime, + }) }) } diff --git a/noir_stdlib/src/meta/trait_constraint.nr b/noir_stdlib/src/meta/trait_constraint.nr index 002a3cf4e75..f0276608974 100644 --- a/noir_stdlib/src/meta/trait_constraint.nr +++ b/noir_stdlib/src/meta/trait_constraint.nr @@ -1,4 +1,4 @@ -use crate::hash::{ Hash, Hasher }; +use crate::hash::{Hash, Hasher}; use crate::cmp::Eq; impl Eq for TraitConstraint { diff --git a/test_programs/compile_success_empty/comptime_trait_constraint/src/main.nr b/test_programs/compile_success_empty/comptime_trait_constraint/src/main.nr index c493b9c6978..5c99f8c587e 100644 --- a/test_programs/compile_success_empty/comptime_trait_constraint/src/main.nr +++ b/test_programs/compile_success_empty/comptime_trait_constraint/src/main.nr @@ -1,11 +1,12 @@ -use std::hash::{ Hash, Hasher }; +use std::hash::{Hash, Hasher}; trait TraitWithGenerics { fn foo(self) -> (A, B); } fn main() { - comptime { + comptime + { let constraint1 = quote { Default }.as_trait_constraint(); let constraint2 = quote { TraitWithGenerics }.as_trait_constraint(); @@ -29,13 +30,10 @@ comptime struct TestHasher { comptime impl Hasher for TestHasher { comptime fn finish(self) -> Field { - println(self.result); self.result } comptime fn write(&mut self, input: Field) { - println(self.result); self.result += input; - println(self.result); } }