diff --git a/compiler/noirc_evaluator/src/acir/acir_context/brillig_call.rs b/compiler/noirc_evaluator/src/acir/acir_context/brillig_call.rs index 26aee386a77..76cf4a3b43a 100644 --- a/compiler/noirc_evaluator/src/acir/acir_context/brillig_call.rs +++ b/compiler/noirc_evaluator/src/acir/acir_context/brillig_call.rs @@ -13,10 +13,32 @@ use iter_extended::{try_vecmap, vecmap}; use crate::brillig::brillig_ir::artifact::GeneratedBrillig; use crate::errors::{InternalError, RuntimeError}; -use super::generated_acir::BrilligStdlibFunc; +use super::generated_acir::{BrilligStdlibFunc, PLACEHOLDER_BRILLIG_INDEX}; use super::{AcirContext, AcirDynamicArray, AcirType, AcirValue, AcirVar}; impl> AcirContext { + /// Generates a brillig call to a handwritten section of brillig bytecode. + pub(crate) fn stdlib_brillig_call( + &mut self, + predicate: AcirVar, + brillig_stdlib_func: BrilligStdlibFunc, + stdlib_func_bytecode: &GeneratedBrillig, + inputs: Vec, + outputs: Vec, + attempt_execution: bool, + ) -> Result, RuntimeError> { + self.brillig_call( + predicate, + stdlib_func_bytecode, + inputs, + outputs, + attempt_execution, + false, + PLACEHOLDER_BRILLIG_INDEX, + Some(brillig_stdlib_func), + ) + } + #[allow(clippy::too_many_arguments)] pub(crate) fn brillig_call( &mut self, diff --git a/compiler/noirc_evaluator/src/acir/acir_context/generated_acir/brillig_directive.rs b/compiler/noirc_evaluator/src/acir/acir_context/generated_acir/brillig_directive.rs index 5fab9e34523..b05cc68e2e2 100644 --- a/compiler/noirc_evaluator/src/acir/acir_context/generated_acir/brillig_directive.rs +++ b/compiler/noirc_evaluator/src/acir/acir_context/generated_acir/brillig_directive.rs @@ -4,10 +4,51 @@ use acvm::acir::{ BinaryFieldOp, BinaryIntOp, BitSize, HeapVector, IntegerBitSize, MemoryAddress, Opcode as BrilligOpcode, }, + circuit::brillig::BrilligFunctionId, }; use crate::brillig::brillig_ir::artifact::GeneratedBrillig; +/// Brillig calls such as for the Brillig std lib are resolved only after code generation is finished. +/// This index should be used when adding a Brillig call during code generation. +/// Code generation should then keep track of that unresolved call opcode which will be resolved with the +/// correct function index after code generation. +pub(crate) const PLACEHOLDER_BRILLIG_INDEX: BrilligFunctionId = BrilligFunctionId(0); + +#[derive(Debug, Clone)] +pub(crate) struct BrilligStdLib { + pub(crate) invert: GeneratedBrillig, + pub(crate) quotient: GeneratedBrillig, + pub(crate) to_le_bytes: GeneratedBrillig, +} + +impl Default for BrilligStdLib { + fn default() -> Self { + Self { + invert: directive_invert(), + quotient: directive_quotient(), + to_le_bytes: directive_to_radix(), + } + } +} + +impl BrilligStdLib { + pub(crate) fn get_code(&self, func: BrilligStdlibFunc) -> &GeneratedBrillig { + match func { + BrilligStdlibFunc::Inverse => &self.invert, + BrilligStdlibFunc::Quotient => &self.quotient, + BrilligStdlibFunc::ToLeBytes => &self.to_le_bytes, + } + } +} + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +pub(crate) enum BrilligStdlibFunc { + Inverse, + Quotient, + ToLeBytes, +} + /// Generates brillig bytecode which computes the inverse of its input if not null, and zero else. pub(crate) fn directive_invert() -> GeneratedBrillig { // We generate the following code: diff --git a/compiler/noirc_evaluator/src/acir/acir_context/generated_acir/mod.rs b/compiler/noirc_evaluator/src/acir/acir_context/generated_acir/mod.rs index b5f8374f77e..c44440468b8 100644 --- a/compiler/noirc_evaluator/src/acir/acir_context/generated_acir/mod.rs +++ b/compiler/noirc_evaluator/src/acir/acir_context/generated_acir/mod.rs @@ -31,11 +31,7 @@ use num_bigint::BigUint; mod brillig_directive; -/// Brillig calls such as for the Brillig std lib are resolved only after code generation is finished. -/// This index should be used when adding a Brillig call during code generation. -/// Code generation should then keep track of that unresolved call opcode which will be resolved with the -/// correct function index after code generation. -pub(super) const PLACEHOLDER_BRILLIG_INDEX: BrilligFunctionId = BrilligFunctionId(0); +pub(crate) use brillig_directive::{BrilligStdLib, BrilligStdlibFunc, PLACEHOLDER_BRILLIG_INDEX}; #[derive(Debug, Default)] /// The output of the Acir-gen pass, which should only be produced for entry point Acir functions @@ -97,23 +93,6 @@ pub(crate) type BrilligOpcodeToLocationsMap = BTreeMap; -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -pub(crate) enum BrilligStdlibFunc { - Inverse, - Quotient, - ToLeBytes, -} - -impl BrilligStdlibFunc { - pub(crate) fn get_generated_brillig(&self) -> GeneratedBrillig { - match self { - BrilligStdlibFunc::Inverse => brillig_directive::directive_invert(), - BrilligStdlibFunc::Quotient => brillig_directive::directive_quotient(), - BrilligStdlibFunc::ToLeBytes => brillig_directive::directive_to_radix(), - } - } -} - impl GeneratedAcir { /// Returns the current witness index. pub(crate) fn current_witness_index(&self) -> Witness { diff --git a/compiler/noirc_evaluator/src/acir/acir_context/mod.rs b/compiler/noirc_evaluator/src/acir/acir_context/mod.rs index 0516fb6f2dd..0d75646a197 100644 --- a/compiler/noirc_evaluator/src/acir/acir_context/mod.rs +++ b/compiler/noirc_evaluator/src/acir/acir_context/mod.rs @@ -37,9 +37,8 @@ use super::{ types::{AcirType, AcirVar}, }; use big_int::BigIntContext; -use generated_acir::PLACEHOLDER_BRILLIG_INDEX; -pub(crate) use generated_acir::{BrilligStdlibFunc, GeneratedAcir}; +pub(crate) use generated_acir::{BrilligStdLib, BrilligStdlibFunc, GeneratedAcir}; #[derive(Debug, Default)] /// Context object which holds the relationship between @@ -47,6 +46,7 @@ pub(crate) use generated_acir::{BrilligStdlibFunc, GeneratedAcir}; /// which are placed into ACIR. pub(crate) struct AcirContext> { pub(super) blackbox_solver: B, + brillig_stdlib: BrilligStdLib, vars: HashMap>, @@ -70,6 +70,19 @@ pub(crate) struct AcirContext> { } impl> AcirContext { + pub(super) fn new(brillig_stdlib: BrilligStdLib, blackbox_solver: B) -> Self { + AcirContext { + brillig_stdlib, + blackbox_solver, + vars: Default::default(), + constant_witnesses: Default::default(), + acir_ir: Default::default(), + big_int_ctx: Default::default(), + expression_width: Default::default(), + warnings: Default::default(), + } + } + pub(crate) fn set_expression_width(&mut self, expression_width: ExpressionWidth) { self.expression_width = expression_width; } @@ -273,18 +286,13 @@ impl> AcirContext { return Ok(inverted_var); } - // Compute the inverse with brillig code - let inverse_code = BrilligStdlibFunc::Inverse.get_generated_brillig(); - - let results = self.brillig_call( + let results = self.stdlib_brillig_call( predicate, - &inverse_code, + BrilligStdlibFunc::Inverse, + &self.brillig_stdlib.get_code(BrilligStdlibFunc::Inverse).clone(), vec![AcirValue::Var(var, AcirType::field())], vec![AcirType::field()], true, - false, - PLACEHOLDER_BRILLIG_INDEX, - Some(BrilligStdlibFunc::Inverse), )?; let inverted_var = Self::expect_one_var(results); @@ -799,18 +807,16 @@ impl> AcirContext { }; let [q_value, r_value]: [AcirValue; 2] = self - .brillig_call( + .stdlib_brillig_call( predicate, - &BrilligStdlibFunc::Quotient.get_generated_brillig(), + BrilligStdlibFunc::Quotient, + &self.brillig_stdlib.get_code(BrilligStdlibFunc::Quotient).clone(), vec![ AcirValue::Var(lhs, AcirType::unsigned(bit_size)), AcirValue::Var(rhs, AcirType::unsigned(bit_size)), ], vec![AcirType::unsigned(max_q_bits), AcirType::unsigned(max_rhs_bits)], true, - false, - PLACEHOLDER_BRILLIG_INDEX, - Some(BrilligStdlibFunc::Quotient), )? .try_into() .expect("quotient only returns two values"); diff --git a/compiler/noirc_evaluator/src/acir/mod.rs b/compiler/noirc_evaluator/src/acir/mod.rs index 43fc3da46f8..a44a1319352 100644 --- a/compiler/noirc_evaluator/src/acir/mod.rs +++ b/compiler/noirc_evaluator/src/acir/mod.rs @@ -55,11 +55,13 @@ use crate::ssa::{ ssa_gen::Ssa, }; pub(crate) use acir_context::GeneratedAcir; -use acir_context::{AcirContext, BrilligStdlibFunc, power_of_two}; +use acir_context::{AcirContext, BrilligStdLib, BrilligStdlibFunc, power_of_two}; use types::{AcirType, AcirVar}; #[derive(Default)] -struct SharedContext { +struct SharedContext { + brillig_stdlib: BrilligStdLib, + /// Final list of Brillig functions which will be part of the final program /// This is shared across `Context` structs as we want one list of Brillig /// functions across all ACIR artifacts @@ -122,14 +124,14 @@ impl SharedContext { { self.add_call_to_resolve(func_id, (opcode_location, generated_pointer)); } else { - let code = brillig_stdlib_func.get_generated_brillig(); + let code = self.brillig_stdlib.get_code(*brillig_stdlib_func); let generated_pointer = self.new_generated_pointer(); self.insert_generated_brillig_stdlib( *brillig_stdlib_func, generated_pointer, func_id, opcode_location, - code, + code.clone(), ); } } @@ -224,9 +226,10 @@ impl<'a> Context<'a> { shared_context: &'a mut SharedContext, expression_width: ExpressionWidth, brillig: &'a Brillig, + brillig_stdlib: BrilligStdLib, brillig_options: &'a BrilligOptions, ) -> Context<'a> { - let mut acir_context = AcirContext::default(); + let mut acir_context = AcirContext::new(brillig_stdlib, Bn254BlackBoxSolver::default()); acir_context.set_expression_width(expression_width); let current_side_effects_enabled_var = acir_context.add_constant(FieldElement::one()); @@ -2760,7 +2763,7 @@ mod test { }, circuit::{ ExpressionWidth, Opcode, OpcodeLocation, - brillig::{BrilligBytecode, BrilligFunctionId}, + brillig::BrilligFunctionId, opcodes::{AcirFunctionId, BlackBoxFuncCall}, }, native_types::{Witness, WitnessMap}, @@ -2773,8 +2776,8 @@ mod test { use std::collections::BTreeMap; use crate::{ - acir::BrilligStdlibFunc, - brillig::{Brillig, BrilligOptions}, + acir::{BrilligStdlibFunc, acir_context::BrilligStdLib, ssa::codegen_acir}, + brillig::{Brillig, BrilligOptions, brillig_ir::artifact::GeneratedBrillig}, ssa::{ function_builder::FunctionBuilder, ir::{ @@ -3595,16 +3598,6 @@ mod test { }"; let ssa = Ssa::from_str(src).unwrap(); - let (acir_functions, mut brillig_functions, _, _) = ssa - .into_acir(&Brillig::default(), &BrilligOptions::default(), ExpressionWidth::default()) - .expect("Should compile manually written SSA into ACIR"); - - assert_eq!(acir_functions.len(), 1); - // [`directive_quotient`, `directive_invert`] - assert_eq!(brillig_functions.len(), 2); - - let main = &acir_functions[0]; - // Here we're attempting to perform a truncation of a `Field` type into 32 bits. We then do a euclidean // division `a/b` with `a` and `b` taking the values: // @@ -3635,8 +3628,8 @@ mod test { // This brillig function replaces the standard implementation of `directive_quotient` with // an implementation which returns `(malicious_q, malicious_r)`. - let malicious_quotient = BrilligBytecode { - bytecode: vec![ + let malicious_quotient = GeneratedBrillig { + byte_code: vec![ BrilligOpcode::Const { destination: MemoryAddress::direct(10), bit_size: BitSize::Integer(IntegerBitSize::U32), @@ -3664,15 +3657,34 @@ mod test { }, }, ], + name: "malicious_directive_quotient".to_string(), + ..Default::default() }; - let malicious_brillig = [malicious_quotient, brillig_functions.remove(1)]; + + let malicious_brillig_stdlib = + BrilligStdLib { quotient: malicious_quotient, ..BrilligStdLib::default() }; + + let (acir_functions, brillig_functions, _, _) = codegen_acir( + ssa, + &Brillig::default(), + malicious_brillig_stdlib, + &BrilligOptions::default(), + ExpressionWidth::default(), + ) + .expect("Should compile manually written SSA into ACIR"); + + assert_eq!(acir_functions.len(), 1); + // [`malicious_directive_quotient`, `directive_invert`] + assert_eq!(brillig_functions.len(), 2); + + let main = &acir_functions[0]; let initial_witness = WitnessMap::from(BTreeMap::from([(Witness(0), input)])); let mut acvm = ACVM::new( &StubbedBlackBoxSolver(true), main.opcodes(), initial_witness, - &malicious_brillig, + &brillig_functions, &[], ); diff --git a/compiler/noirc_evaluator/src/acir/ssa.rs b/compiler/noirc_evaluator/src/acir/ssa.rs index 2cefca44cc7..df000d44b7e 100644 --- a/compiler/noirc_evaluator/src/acir/ssa.rs +++ b/compiler/noirc_evaluator/src/acir/ssa.rs @@ -12,7 +12,7 @@ use crate::{ ssa::ssa_gen::Ssa, }; -use super::{Context, GeneratedAcir, SharedContext}; +use super::{Context, GeneratedAcir, SharedContext, acir_context::BrilligStdLib}; pub(crate) type Artifacts = ( Vec>, @@ -29,55 +29,69 @@ impl Ssa { brillig_options: &BrilligOptions, expression_width: ExpressionWidth, ) -> Result { - let mut acirs = Vec::new(); - // TODO: can we parallelize this? - let mut shared_context = SharedContext::default(); + codegen_acir(self, brillig, BrilligStdLib::default(), brillig_options, expression_width) + } +} - for function in self.functions.values() { - let context = - Context::new(&mut shared_context, expression_width, brillig, brillig_options); +pub(super) fn codegen_acir( + ssa: Ssa, + brillig: &Brillig, + brillig_stdlib: BrilligStdLib, + brillig_options: &BrilligOptions, + expression_width: ExpressionWidth, +) -> Result { + let mut acirs = Vec::new(); + // TODO: can we parallelize this? + let mut shared_context = + SharedContext { brillig_stdlib: brillig_stdlib.clone(), ..SharedContext::default() }; - if let Some(mut generated_acir) = context.convert_ssa_function(&self, function)? { - // We want to be able to insert Brillig stdlib functions anywhere during the ACIR generation process (e.g. such as on the `GeneratedAcir`). - // As we don't want a reference to the `SharedContext` on the generated ACIR itself, - // we instead store the opcode location at which a Brillig call to a std lib function occurred. - // We then defer resolving the function IDs of those Brillig functions to when we have generated Brillig - // for all normal Brillig calls. - for (opcode_location, brillig_stdlib_func) in - &generated_acir.brillig_stdlib_func_locations - { - shared_context.generate_brillig_calls_to_resolve( - brillig_stdlib_func, - function.id(), - *opcode_location, - ); - } + for function in ssa.functions.values() { + let context = Context::new( + &mut shared_context, + expression_width, + brillig, + brillig_stdlib.clone(), + brillig_options, + ); - // Fetch the Brillig stdlib calls to resolve for this function - if let Some(calls_to_resolve) = - shared_context.brillig_stdlib_calls_to_resolve.get(&function.id()) - { - // Resolve the Brillig stdlib calls - // We have to do a separate loop as the generated ACIR cannot be borrowed as mutable after an immutable borrow - for (opcode_location, brillig_function_pointer) in calls_to_resolve { - generated_acir.resolve_brillig_stdlib_call( - *opcode_location, - *brillig_function_pointer, - ); - } - } + if let Some(mut generated_acir) = context.convert_ssa_function(&ssa, function)? { + // We want to be able to insert Brillig stdlib functions anywhere during the ACIR generation process (e.g. such as on the `GeneratedAcir`). + // As we don't want a reference to the `SharedContext` on the generated ACIR itself, + // we instead store the opcode location at which a Brillig call to a std lib function occurred. + // We then defer resolving the function IDs of those Brillig functions to when we have generated Brillig + // for all normal Brillig calls. + for (opcode_location, brillig_stdlib_func) in + &generated_acir.brillig_stdlib_func_locations + { + shared_context.generate_brillig_calls_to_resolve( + brillig_stdlib_func, + function.id(), + *opcode_location, + ); + } - generated_acir.name = function.name().to_owned(); - acirs.push(generated_acir); + // Fetch the Brillig stdlib calls to resolve for this function + if let Some(calls_to_resolve) = + shared_context.brillig_stdlib_calls_to_resolve.get(&function.id()) + { + // Resolve the Brillig stdlib calls + // We have to do a separate loop as the generated ACIR cannot be borrowed as mutable after an immutable borrow + for (opcode_location, brillig_function_pointer) in calls_to_resolve { + generated_acir + .resolve_brillig_stdlib_call(*opcode_location, *brillig_function_pointer); + } } + + generated_acir.name = function.name().to_owned(); + acirs.push(generated_acir); } + } - let (brillig_bytecode, brillig_names) = shared_context - .generated_brillig - .into_iter() - .map(|brillig| (BrilligBytecode { bytecode: brillig.byte_code }, brillig.name)) - .unzip(); + let (brillig_bytecode, brillig_names) = shared_context + .generated_brillig + .into_iter() + .map(|brillig| (BrilligBytecode { bytecode: brillig.byte_code }, brillig.name)) + .unzip(); - Ok((acirs, brillig_bytecode, brillig_names, self.error_selector_to_type)) - } + Ok((acirs, brillig_bytecode, brillig_names, ssa.error_selector_to_type)) } diff --git a/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs b/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs index 42052c09230..b64c67886ab 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_ir/artifact.rs @@ -21,7 +21,7 @@ pub(crate) enum BrilligParameter { /// The result of compiling and linking brillig artifacts. /// This is ready to run bytecode with attached metadata. -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub(crate) struct GeneratedBrillig { pub(crate) byte_code: Vec>, pub(crate) locations: BTreeMap,