diff --git a/compiler/noirc_evaluator/src/ssa/interpreter/mod.rs b/compiler/noirc_evaluator/src/ssa/interpreter/mod.rs index 4905f72eb4f..f16161b1310 100644 --- a/compiler/noirc_evaluator/src/ssa/interpreter/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/interpreter/mod.rs @@ -94,12 +94,15 @@ impl<'ssa> Interpreter<'ssa> { &self.call_stack.first().expect("call_stack should always be non-empty").scope } - fn current_function(&self) -> &'ssa Function { + fn try_current_function(&self) -> Option<&'ssa Function> { let current_function_id = self.call_context().called_function; - let current_function_id = current_function_id.expect( + current_function_id.map(|current_function_id| &self.ssa.functions[¤t_function_id]) + } + + fn current_function(&self) -> &'ssa Function { + self.try_current_function().expect( "Tried calling `Interpreter::current_function` while evaluating global instructions", - ); - &self.ssa.functions[¤t_function_id] + ) } fn dfg(&self) -> &'ssa DataFlowGraph { @@ -376,9 +379,17 @@ impl<'ssa> Interpreter<'ssa> { try_vecmap(ids, |id| self.lookup(*id)) } - fn side_effects_enabled(&self) -> bool { - match self.current_function().runtime() { - RuntimeType::Acir(_) => self.side_effects_enabled, + fn side_effects_enabled(&self, instruction: &Instruction) -> bool { + let Some(current_function) = self.try_current_function() else { + // If there's no current function it means we are evaluating global instructions + return true; + }; + + match current_function.runtime() { + RuntimeType::Acir(_) => { + self.side_effects_enabled + || !instruction.requires_acir_gen_predicate(¤t_function.dfg) + } RuntimeType::Brillig(_) => true, } } @@ -389,9 +400,11 @@ impl<'ssa> Interpreter<'ssa> { instruction: &Instruction, results: &[ValueId], ) -> IResult<()> { + let side_effects_enabled = self.side_effects_enabled(instruction); + match instruction { Instruction::Binary(binary) => { - let result = self.interpret_binary(binary)?; + let result = self.interpret_binary(binary, side_effects_enabled)?; self.define(results[0], result); Ok(()) } @@ -410,7 +423,7 @@ impl<'ssa> Interpreter<'ssa> { Instruction::Constrain(lhs_id, rhs_id, constrain_error) => { let lhs = self.lookup(*lhs_id)?; let rhs = self.lookup(*rhs_id)?; - if self.side_effects_enabled() && lhs != rhs { + if side_effects_enabled && lhs != rhs { let lhs = lhs.to_string(); let rhs = rhs.to_string(); let lhs_id = *lhs_id; @@ -433,7 +446,7 @@ impl<'ssa> Interpreter<'ssa> { Instruction::ConstrainNotEqual(lhs_id, rhs_id, constrain_error) => { let lhs = self.lookup(*lhs_id)?; let rhs = self.lookup(*rhs_id)?; - if self.side_effects_enabled() && lhs == rhs { + if side_effects_enabled && lhs == rhs { let lhs = lhs.to_string(); let rhs = rhs.to_string(); let lhs_id = *lhs_id; @@ -453,10 +466,16 @@ impl<'ssa> Interpreter<'ssa> { } Ok(()) } - Instruction::RangeCheck { value, max_bit_size, assert_message } => { - self.interpret_range_check(*value, *max_bit_size, assert_message.as_ref()) + Instruction::RangeCheck { value, max_bit_size, assert_message } => self + .interpret_range_check( + *value, + *max_bit_size, + assert_message.as_ref(), + side_effects_enabled, + ), + Instruction::Call { func, arguments } => { + self.interpret_call(*func, arguments, results, side_effects_enabled) } - Instruction::Call { func, arguments } => self.interpret_call(*func, arguments, results), Instruction::Allocate => { self.interpret_allocate(results[0]); Ok(()) @@ -468,11 +487,18 @@ impl<'ssa> Interpreter<'ssa> { Ok(()) } Instruction::ArrayGet { array, index, offset } => { - self.interpret_array_get(*array, *index, *offset, results[0]) - } - Instruction::ArraySet { array, index, value, mutable, offset } => { - self.interpret_array_set(*array, *index, *value, *mutable, *offset, results[0]) + self.interpret_array_get(*array, *index, *offset, results[0], side_effects_enabled) } + Instruction::ArraySet { array, index, value, mutable, offset } => self + .interpret_array_set( + *array, + *index, + *value, + *mutable, + *offset, + results[0], + side_effects_enabled, + ), Instruction::IncrementRc { value } => self.interpret_inc_rc(*value), Instruction::DecrementRc { value } => self.interpret_dec_rc(*value), Instruction::IfElse { then_condition, then_value, else_condition, else_value } => self @@ -556,8 +582,9 @@ impl<'ssa> Interpreter<'ssa> { value_id: ValueId, max_bit_size: u32, error_message: Option<&String>, + side_effects_enabled: bool, ) -> IResult<()> { - if !self.side_effects_enabled() { + if !side_effects_enabled { return Ok(()); } @@ -628,11 +655,12 @@ impl<'ssa> Interpreter<'ssa> { function_id: ValueId, argument_ids: &[ValueId], results: &[ValueId], + side_effects_enabled: bool, ) -> IResult<()> { let function = self.lookup(function_id)?; let mut arguments = try_vecmap(argument_ids, |argument| self.lookup(*argument))?; - let new_results = if self.side_effects_enabled() { + let new_results = if side_effects_enabled { match function { Value::Function(id) => { // If we're crossing a constrained -> unconstrained boundary we have to wipe @@ -762,8 +790,9 @@ impl<'ssa> Interpreter<'ssa> { index: ValueId, offset: ArrayOffset, result: ValueId, + side_effects_enabled: bool, ) -> IResult<()> { - let element = if self.side_effects_enabled() { + let element = if side_effects_enabled { let array = self.lookup_array_or_slice(array, "array get")?; let index = self.lookup_u32(index, "array get index")?; let index = index - offset.to_u32(); @@ -776,6 +805,7 @@ impl<'ssa> Interpreter<'ssa> { Ok(()) } + #[allow(clippy::too_many_arguments)] fn interpret_array_set( &mut self, array: ValueId, @@ -784,10 +814,11 @@ impl<'ssa> Interpreter<'ssa> { mutable: bool, offset: ArrayOffset, result: ValueId, + side_effects_enabled: bool, ) -> IResult<()> { let array = self.lookup_array_or_slice(array, "array set")?; - let result_array = if self.side_effects_enabled() { + let result_array = if side_effects_enabled { let index = self.lookup_u32(index, "array set index")?; let index = index - offset.to_u32(); let value = self.lookup(value)?; @@ -1047,7 +1078,7 @@ macro_rules! apply_int_comparison_op { } impl Interpreter<'_> { - fn interpret_binary(&mut self, binary: &Binary) -> IResult { + fn interpret_binary(&mut self, binary: &Binary, side_effects_enabled: bool) -> IResult { let lhs_id = binary.lhs; let rhs_id = binary.rhs; let lhs = self.lookup_numeric(lhs_id, "binary op lhs")?; @@ -1066,7 +1097,7 @@ impl Interpreter<'_> { } // Disable this instruction if it is side-effectful and side effects are disabled. - if !self.side_effects_enabled() && binary.requires_acir_gen_predicate(self.dfg()) { + if !side_effects_enabled { let zero = NumericValue::zero(lhs.get_type()); return Ok(Value::Numeric(zero)); } diff --git a/compiler/noirc_evaluator/src/ssa/interpreter/tests/instructions.rs b/compiler/noirc_evaluator/src/ssa/interpreter/tests/instructions.rs index f5cda1eb7e5..6b52388fbdc 100644 --- a/compiler/noirc_evaluator/src/ssa/interpreter/tests/instructions.rs +++ b/compiler/noirc_evaluator/src/ssa/interpreter/tests/instructions.rs @@ -6,7 +6,10 @@ use noirc_frontend::Shared; use crate::ssa::{ interpreter::{ InterpreterError, NumericValue, Value, - tests::{expect_value, expect_values, expect_values_with_args, from_constant}, + tests::{ + expect_value, expect_value_with_args, expect_values, expect_values_with_args, + from_constant, + }, value::ReferenceValue, }, ir::{ @@ -627,8 +630,8 @@ fn constrain() { } #[test] -fn constrain_disabled_by_enable_side_effects() { - executes_with_no_errors( +fn constrain_not_disabled_by_enable_side_effects() { + expect_error( " acir(inline) fn main f0 { b0(): @@ -640,34 +643,33 @@ fn constrain_disabled_by_enable_side_effects() { ); } -// SSA Parser does not yet parse ConstrainNotEqual -// #[test] -// fn constrain_not_equal() { -// executes_with_no_errors( -// " -// acir(inline) fn main f0 { -// b0(): -// v0 = eq u8 3, u8 4 -// constrain v0 != u1 1 -// return -// } -// ", -// ); -// } -// -// #[test] -// fn constrain_not_equal_disabled_by_enable_side_effects() { -// executes_with_no_errors( -// " -// acir(inline) fn main f0 { -// b0(): -// enable_side_effects u1 0 -// constrain u1 1 != u1 1 -// return -// } -// ", -// ); -// } +#[test] +fn constrain_not_equal() { + executes_with_no_errors( + " + acir(inline) fn main f0 { + b0(): + v0 = eq u8 3, u8 4 + constrain v0 != u1 1 + return + } + ", + ); +} + +#[test] +fn constrain_not_equal_not_disabled_by_enable_side_effects() { + expect_error( + " + acir(inline) fn main f0 { + b0(): + enable_side_effects u1 0 + constrain u1 1 != u1 1 + return + } + ", + ); +} #[test] fn range_check() { @@ -697,8 +699,8 @@ fn range_check_fail() { } #[test] -fn range_check_disabled_by_enable_side_effects() { - executes_with_no_errors( +fn range_check_not_disabled_by_enable_side_effects() { + expect_error( " acir(inline) fn main f0 { b0(): @@ -846,7 +848,7 @@ fn array_get_with_offset() { } #[test] -fn array_get_disabled_by_enable_side_effects() { +fn array_get_not_disabled_by_enable_side_effects_if_index_is_known_to_be_safe() { let value = expect_value( r#" acir(inline) fn main f0 { @@ -858,6 +860,23 @@ fn array_get_disabled_by_enable_side_effects() { } "#, ); + assert_eq!(value, from_constant(2_u32.into(), NumericType::NativeField)); +} + +#[test] +fn array_get_disabled_by_enable_side_effects_if_index_is_not_known_to_be_safe() { + let value = expect_value_with_args( + r#" + acir(inline) fn main f0 { + b0(v2: u32): + enable_side_effects u1 0 + v0 = make_array [Field 1, Field 2] : [Field; 2] + v1 = array_get v0, index v2 -> Field + return v1 + } + "#, + vec![Value::Numeric(NumericValue::U32(1))], + ); assert_eq!(value, from_constant(0_u32.into(), NumericType::NativeField)); }