diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs index 4a0f9f798ff..c4b4301121b 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs @@ -255,7 +255,7 @@ impl AcirContext { } /// Converts an [`AcirVar`] to a [`Witness`] - fn var_to_witness(&mut self, var: AcirVar) -> Result { + pub(crate) fn var_to_witness(&mut self, var: AcirVar) -> Result { let expression = self.var_to_expression(var)?; let witness = if let Some(constant) = expression.to_const() { // Check if a witness has been assigned this value already, if so reuse it. @@ -1017,15 +1017,6 @@ impl AcirContext { Ok(remainder) } - /// Converts the `AcirVar` to a `Witness` if it hasn't been already, and appends it to the - /// `GeneratedAcir`'s return witnesses. - pub(crate) fn return_var(&mut self, acir_var: AcirVar) -> Result<(), InternalError> { - let return_var = self.get_or_create_witness_var(acir_var)?; - let witness = self.var_to_witness(return_var)?; - self.acir_ir.push_return_witness(witness); - Ok(()) - } - /// Constrains the `AcirVar` variable to be of type `NumericType`. pub(crate) fn range_constrain_var( &mut self, @@ -1528,9 +1519,11 @@ impl AcirContext { pub(crate) fn finish( mut self, inputs: Vec, + return_values: Vec, warnings: Vec, ) -> GeneratedAcir { self.acir_ir.input_witnesses = inputs; + self.acir_ir.return_witnesses = return_values; self.acir_ir.warnings = warnings; self.acir_ir } diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs index 6c79c0a228d..9a09e7c06ee 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs @@ -45,9 +45,6 @@ pub(crate) struct GeneratedAcir { opcodes: Vec>, /// All witness indices that comprise the final return value of the program - /// - /// Note: This may contain repeated indices, which is necessary for later mapping into the - /// abi's return type. pub(crate) return_witnesses: Vec, /// All witness indices which are inputs to the main function @@ -164,11 +161,6 @@ impl GeneratedAcir { fresh_witness } - - /// Adds a witness index to the program's return witnesses. - pub(crate) fn push_return_witness(&mut self, witness: Witness) { - self.return_witnesses.push(witness); - } } impl GeneratedAcir { diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs index 13677506d0b..6d7c5e570c1 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs @@ -36,11 +36,7 @@ use acvm::acir::circuit::brillig::BrilligBytecode; use acvm::acir::circuit::{AssertionPayload, ErrorSelector, OpcodeLocation}; use acvm::acir::native_types::Witness; use acvm::acir::BlackBoxFunc; -use acvm::{ - acir::AcirField, - acir::{circuit::opcodes::BlockId, native_types::Expression}, - FieldElement, -}; +use acvm::{acir::circuit::opcodes::BlockId, acir::AcirField, FieldElement}; use fxhash::FxHashMap as HashMap; use im::Vector; use iter_extended::{try_vecmap, vecmap}; @@ -330,38 +326,10 @@ impl Ssa { bytecode: brillig.byte_code, }); - let runtime_types = self.functions.values().map(|function| function.runtime()); - for (acir, runtime_type) in acirs.iter_mut().zip(runtime_types) { - if matches!(runtime_type, RuntimeType::Acir(_)) { - generate_distinct_return_witnesses(acir); - } - } - Ok((acirs, brillig, self.error_selector_to_type)) } } -fn generate_distinct_return_witnesses(acir: &mut GeneratedAcir) { - // Create a witness for each return witness we have to guarantee that the return witnesses match the standard - // layout for serializing those types as if they were being passed as inputs. - // - // This is required for recursion as otherwise in situations where we cannot make use of the program's ABI - // (e.g. for `std::verify_proof` or the solidity verifier), we need extra knowledge about the program we're - // working with rather than following the standard ABI encoding rules. - // - // TODO: We're being conservative here by generating a new witness for every expression. - // This means that we're likely to get a number of constraints which are just renumbering witnesses. - // This can be tackled by: - // - Tracking the last assigned public input witness and only renumbering a witness if it is below this value. - // - Modifying existing constraints to rearrange their outputs so they are suitable - // - See: https://github.com/noir-lang/noir/pull/4467 - let distinct_return_witness = vecmap(acir.return_witnesses.clone(), |return_witness| { - acir.create_witness_for_expression(&Expression::from(return_witness)) - }); - - acir.return_witnesses = distinct_return_witness; -} - impl<'a> Context<'a> { fn new(shared_context: &'a mut SharedContext) -> Context<'a> { let mut acir_context = AcirContext::default(); @@ -422,6 +390,25 @@ impl<'a> Context<'a> { let dfg = &main_func.dfg; let entry_block = &dfg[main_func.entry_block()]; let input_witness = self.convert_ssa_block_params(entry_block.parameters(), dfg)?; + let num_return_witnesses = + self.get_num_return_witnesses(entry_block.unwrap_terminator(), dfg); + + // Create a witness for each return witness we have to guarantee that the return witnesses match the standard + // layout for serializing those types as if they were being passed as inputs. + // + // This is required for recursion as otherwise in situations where we cannot make use of the program's ABI + // (e.g. for `std::verify_proof` or the solidity verifier), we need extra knowledge about the program we're + // working with rather than following the standard ABI encoding rules. + // + // We allocate these witnesses now before performing ACIR gen for the rest of the program as the location of + // the function's return values can then be determined through knowledge of its ABI alone. + let return_witness_vars = + vecmap(0..num_return_witnesses, |_| self.acir_context.add_variable()); + + let return_witnesses = vecmap(&return_witness_vars, |return_var| { + let expr = self.acir_context.var_to_expression(*return_var).unwrap(); + expr.to_witness().expect("return vars should be witnesses") + }); self.data_bus = dfg.data_bus.to_owned(); let mut warnings = Vec::new(); @@ -429,8 +416,19 @@ impl<'a> Context<'a> { warnings.extend(self.convert_ssa_instruction(*instruction_id, dfg, ssa, brillig)?); } - warnings.extend(self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?); - Ok(self.acir_context.finish(input_witness, warnings)) + let (return_vars, return_warnings) = + self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?; + + // TODO: This is a naive method of assigning the return values to their witnesses as + // we're likely to get a number of constraints which are asserting one witness to be equal to another. + // + // We should search through the program and relabel these witnesses so we can remove this constraint. + for (witness_var, return_var) in return_witness_vars.iter().zip(return_vars) { + self.acir_context.assert_eq_var(*witness_var, return_var, None)?; + } + + warnings.extend(return_warnings); + Ok(self.acir_context.finish(input_witness, return_witnesses, warnings)) } fn convert_brillig_main( @@ -468,17 +466,13 @@ impl<'a> Context<'a> { )?; self.shared_context.insert_generated_brillig(main_func.id(), arguments, 0, code); - let output_vars: Vec<_> = output_values + let return_witnesses: Vec = output_values .iter() .flat_map(|value| value.clone().flatten()) - .map(|value| value.0) - .collect(); + .map(|(value, _)| self.acir_context.var_to_witness(value)) + .collect::>()?; - for acir_var in output_vars { - self.acir_context.return_var(acir_var)?; - } - - let generated_acir = self.acir_context.finish(witness_inputs, Vec::new()); + let generated_acir = self.acir_context.finish(witness_inputs, return_witnesses, Vec::new()); assert_eq!( generated_acir.opcodes().len(), @@ -1724,12 +1718,39 @@ impl<'a> Context<'a> { self.define_result(dfg, instruction, AcirValue::Var(result, typ)); } + /// Converts an SSA terminator's return values into their ACIR representations + fn get_num_return_witnesses( + &mut self, + terminator: &TerminatorInstruction, + dfg: &DataFlowGraph, + ) -> usize { + let return_values = match terminator { + TerminatorInstruction::Return { return_values, .. } => return_values, + // TODO(https://github.com/noir-lang/noir/issues/4616): Enable recursion on foldable/non-inlined ACIR functions + _ => unreachable!("ICE: Program must have a singular return"), + }; + + return_values.iter().fold(0, |acc, value_id| { + let is_databus = self + .data_bus + .return_data + .map_or(false, |return_databus| dfg[*value_id] == dfg[return_databus]); + + if is_databus { + // We do not return value for the data bus. + acc + } else { + acc + dfg.type_of_value(*value_id).flattened_size() + } + }) + } + /// Converts an SSA terminator's return values into their ACIR representations fn convert_ssa_return( &mut self, terminator: &TerminatorInstruction, dfg: &DataFlowGraph, - ) -> Result, RuntimeError> { + ) -> Result<(Vec, Vec), RuntimeError> { let (return_values, call_stack) = match terminator { TerminatorInstruction::Return { return_values, call_stack } => { (return_values, call_stack.clone()) @@ -1739,6 +1760,7 @@ impl<'a> Context<'a> { }; let mut has_constant_return = false; + let mut return_vars: Vec = Vec::new(); for value_id in return_values { let is_databus = self .data_bus @@ -1759,7 +1781,7 @@ impl<'a> Context<'a> { dfg, )?; } else { - self.acir_context.return_var(acir_var)?; + return_vars.push(acir_var); } } } @@ -1770,7 +1792,7 @@ impl<'a> Context<'a> { Vec::new() }; - Ok(warnings) + Ok((return_vars, warnings)) } /// Gets the cached `AcirVar` that was converted from the corresponding `ValueId`. If it does @@ -3079,8 +3101,8 @@ mod test { check_call_opcode( &func_with_nested_call_opcodes[1], 2, - vec![Witness(2), Witness(1)], - vec![Witness(3)], + vec![Witness(3), Witness(1)], + vec![Witness(4)], ); } @@ -3100,13 +3122,13 @@ mod test { for (expected_input, input) in expected_inputs.iter().zip(inputs) { assert_eq!( expected_input, input, - "Expected witness {expected_input:?} but got {input:?}" + "Expected input witness {expected_input:?} but got {input:?}" ); } for (expected_output, output) in expected_outputs.iter().zip(outputs) { assert_eq!( expected_output, output, - "Expected witness {expected_output:?} but got {output:?}" + "Expected output witness {expected_output:?} but got {output:?}" ); } }