diff --git a/compiler/noirc_driver/src/abi_gen.rs b/compiler/noirc_driver/src/abi_gen.rs index d0b33945f40..6625eba5a0a 100644 --- a/compiler/noirc_driver/src/abi_gen.rs +++ b/compiler/noirc_driver/src/abi_gen.rs @@ -116,7 +116,7 @@ fn to_abi_visibility(value: Visibility) -> AbiVisibility { match value { Visibility::Public => AbiVisibility::Public, Visibility::Private => AbiVisibility::Private, - Visibility::DataBus => AbiVisibility::DataBus, + Visibility::CallData(_) | Visibility::ReturnData => AbiVisibility::DataBus, } } diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs index dc0e026b3d1..d7dd5e5dbce 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs @@ -430,6 +430,12 @@ impl<'a> Context<'a> { let (return_vars, return_warnings) = self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?; + let call_data_arrays: Vec = + self.data_bus.call_data.iter().map(|cd| cd.array_id).collect(); + for call_data_array in call_data_arrays { + self.ensure_array_is_initialized(call_data_array, dfg)?; + } + // TODO: This is a naive method of assigning the return values to their witnesses as // we're likely to get a number of constraints which are asserting one witness to be equal to another. // @@ -1263,20 +1269,23 @@ impl<'a> Context<'a> { let res_typ = dfg.type_of_value(results[0]); // Get operations to call-data parameters are replaced by a get to the call-data-bus array - if let Some(call_data) = self.data_bus.call_data { - if self.data_bus.call_data_map.contains_key(&array) { - // TODO: the block_id of call-data must be notified to the backend - // TODO: should we do the same for return-data? - let type_size = res_typ.flattened_size(); - let type_size = - self.acir_context.add_constant(FieldElement::from(type_size as i128)); - let offset = self.acir_context.mul_var(var_index, type_size)?; - let bus_index = self - .acir_context - .add_constant(FieldElement::from(self.data_bus.call_data_map[&array] as i128)); - let new_index = self.acir_context.add_var(offset, bus_index)?; - return self.array_get(instruction, call_data, new_index, dfg, index_side_effect); - } + if let Some(call_data) = + self.data_bus.call_data.iter().find(|cd| cd.index_map.contains_key(&array)) + { + let type_size = res_typ.flattened_size(); + let type_size = self.acir_context.add_constant(FieldElement::from(type_size as i128)); + let offset = self.acir_context.mul_var(var_index, type_size)?; + let bus_index = self + .acir_context + .add_constant(FieldElement::from(call_data.index_map[&array] as i128)); + let new_index = self.acir_context.add_var(offset, bus_index)?; + return self.array_get( + instruction, + call_data.array_id, + new_index, + dfg, + index_side_effect, + ); } // Compiler sanity check @@ -1707,17 +1716,20 @@ impl<'a> Context<'a> { len: usize, value: Option, ) -> Result<(), InternalError> { - let databus = if self.data_bus.call_data.is_some() - && self.block_id(&self.data_bus.call_data.unwrap()) == array - { - BlockType::CallData - } else if self.data_bus.return_data.is_some() + let mut databus = BlockType::Memory; + if self.data_bus.return_data.is_some() && self.block_id(&self.data_bus.return_data.unwrap()) == array { - BlockType::ReturnData - } else { - BlockType::Memory - }; + databus = BlockType::ReturnData; + } + for array_id in self.data_bus.call_data_array() { + if self.block_id(&array_id) == array { + assert!(databus == BlockType::Memory); + databus = BlockType::CallData; + break; + } + } + self.acir_context.initialize_array(array, len, value, databus)?; self.initialized_arrays.insert(array); Ok(()) diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs b/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs index 5f0660f5a79..50964e9161b 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/data_bus.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeMap; use std::rc::Rc; use crate::ssa::ir::{types::Type, value::ValueId}; @@ -8,6 +9,12 @@ use noirc_frontend::hir_def::function::FunctionSignature; use super::FunctionBuilder; +#[derive(Clone)] +pub(crate) enum DatabusVisibility { + None, + CallData(u32), + ReturnData, +} /// Used to create a data bus, which is an array of private inputs /// replacing public inputs pub(crate) struct DataBusBuilder { @@ -27,15 +34,16 @@ impl DataBusBuilder { } } - /// Generates a boolean vector telling which (ssa) parameter from the given function signature + /// Generates a vector telling which (ssa) parameters from the given function signature /// are tagged with databus visibility - pub(crate) fn is_databus(main_signature: &FunctionSignature) -> Vec { + pub(crate) fn is_databus(main_signature: &FunctionSignature) -> Vec { let mut params_is_databus = Vec::new(); for param in &main_signature.0 { let is_databus = match param.2 { - ast::Visibility::Public | ast::Visibility::Private => false, - ast::Visibility::DataBus => true, + ast::Visibility::Public | ast::Visibility::Private => DatabusVisibility::None, + ast::Visibility::CallData(id) => DatabusVisibility::CallData(id), + ast::Visibility::ReturnData => DatabusVisibility::ReturnData, }; let len = param.1.field_count() as usize; params_is_databus.extend(vec![is_databus; len]); @@ -44,34 +52,51 @@ impl DataBusBuilder { } } +#[derive(Clone, Debug)] +pub(crate) struct CallData { + pub(crate) array_id: ValueId, + pub(crate) index_map: HashMap, +} + #[derive(Clone, Default, Debug)] pub(crate) struct DataBus { - pub(crate) call_data: Option, - pub(crate) call_data_map: HashMap, + pub(crate) call_data: Vec, pub(crate) return_data: Option, } impl DataBus { /// Updates the databus values with the provided function pub(crate) fn map_values(&self, mut f: impl FnMut(ValueId) -> ValueId) -> DataBus { - let mut call_data_map = HashMap::default(); - for (k, v) in self.call_data_map.iter() { - call_data_map.insert(f(*k), *v); - } - DataBus { - call_data: self.call_data.map(&mut f), - call_data_map, - return_data: self.return_data.map(&mut f), - } + let call_data = self + .call_data + .iter() + .map(|cd| { + let mut call_data_map = HashMap::default(); + for (k, v) in cd.index_map.iter() { + call_data_map.insert(f(*k), *v); + } + CallData { array_id: f(cd.array_id), index_map: call_data_map } + }) + .collect(); + DataBus { call_data, return_data: self.return_data.map(&mut f) } } + pub(crate) fn call_data_array(&self) -> Vec { + self.call_data.iter().map(|cd| cd.array_id).collect() + } /// Construct a databus from call_data and return_data data bus builders - pub(crate) fn get_data_bus(call_data: DataBusBuilder, return_data: DataBusBuilder) -> DataBus { - DataBus { - call_data: call_data.databus, - call_data_map: call_data.map, - return_data: return_data.databus, + pub(crate) fn get_data_bus( + call_data: Vec, + return_data: DataBusBuilder, + ) -> DataBus { + let mut call_data_args = Vec::new(); + for call_data_item in call_data { + if let Some(array_id) = call_data_item.databus { + call_data_args.push(CallData { array_id, index_map: call_data_item.map }); + } } + + DataBus { call_data: call_data_args, return_data: return_data.databus } } } @@ -129,19 +154,36 @@ impl FunctionBuilder { } /// Generate the data bus for call-data, based on the parameters of the entry block - /// and a boolean vector telling which ones are call-data - pub(crate) fn call_data_bus(&mut self, is_params_databus: Vec) -> DataBusBuilder { + /// and a vector telling which ones are call-data + pub(crate) fn call_data_bus( + &mut self, + is_params_databus: Vec, + ) -> Vec { //filter parameters of the first block that have call-data visibility let first_block = self.current_function.entry_block(); let params = self.current_function.dfg[first_block].parameters(); - let mut databus_param = Vec::new(); - for (param, is_databus) in params.iter().zip(is_params_databus) { - if is_databus { - databus_param.push(param.to_owned()); + let mut databus_param: BTreeMap> = BTreeMap::new(); + for (param, databus_attribute) in params.iter().zip(is_params_databus) { + match databus_attribute { + DatabusVisibility::None | DatabusVisibility::ReturnData => continue, + DatabusVisibility::CallData(call_data_id) => { + if let std::collections::btree_map::Entry::Vacant(e) = + databus_param.entry(call_data_id) + { + e.insert(vec![param.to_owned()]); + } else { + databus_param.get_mut(&call_data_id).unwrap().push(param.to_owned()); + } + } } } - // create the call-data-bus from the filtered list - let call_data = DataBusBuilder::new(); - self.initialize_data_bus(&databus_param, call_data) + // create the call-data-bus from the filtered lists + let mut result = Vec::new(); + for id in databus_param.keys() { + let builder = DataBusBuilder::new(); + let call_databus = self.initialize_data_bus(&databus_param[id], builder); + result.push(call_databus); + } + result } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/die.rs b/compiler/noirc_evaluator/src/ssa/opt/die.rs index 42383680f44..24519d530ee 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/die.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/die.rs @@ -34,8 +34,8 @@ impl Ssa { /// of its instructions are needed elsewhere. fn dead_instruction_elimination(function: &mut Function) { let mut context = Context::default(); - if let Some(call_data) = function.dfg.data_bus.call_data { - context.mark_used_instruction_results(&function.dfg, call_data); + for call_data in &function.dfg.data_bus.call_data { + context.mark_used_instruction_results(&function.dfg, call_data.array_id); } let blocks = PostOrder::with_function(function); diff --git a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs index abd251b008f..468a8573307 100644 --- a/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/ssa_gen/mod.rs @@ -44,7 +44,7 @@ pub(crate) fn generate_ssa( // see which parameter has call_data/return_data attribute let is_databus = DataBusBuilder::is_databus(&program.main_function_signature); - let is_return_data = matches!(program.return_visibility, Visibility::DataBus); + let is_return_data = matches!(program.return_visibility, Visibility::ReturnData); let return_location = program.return_location; let context = SharedContext::new(program); diff --git a/compiler/noirc_frontend/src/ast/mod.rs b/compiler/noirc_frontend/src/ast/mod.rs index 038a13529d7..b97853033a8 100644 --- a/compiler/noirc_frontend/src/ast/mod.rs +++ b/compiler/noirc_frontend/src/ast/mod.rs @@ -390,7 +390,9 @@ pub enum Visibility { Private, /// DataBus is public input handled as private input. We use the fact that return values are properly computed by the program to avoid having them as public inputs /// it is useful for recursion and is handled by the proving system. - DataBus, + /// The u32 value is used to group inputs having the same value. + CallData(u32), + ReturnData, } impl std::fmt::Display for Visibility { @@ -398,7 +400,8 @@ impl std::fmt::Display for Visibility { match self { Self::Public => write!(f, "pub"), Self::Private => write!(f, "priv"), - Self::DataBus => write!(f, "databus"), + Self::CallData(id) => write!(f, "calldata{id}"), + Self::ReturnData => write!(f, "returndata"), } } } diff --git a/compiler/noirc_frontend/src/parser/errors.rs b/compiler/noirc_frontend/src/parser/errors.rs index c566489eb40..80adb01dc9a 100644 --- a/compiler/noirc_frontend/src/parser/errors.rs +++ b/compiler/noirc_frontend/src/parser/errors.rs @@ -46,6 +46,8 @@ pub enum ParserErrorReason { Lexer(LexerErrorKind), #[error("The only supported numeric generic types are `u1`, `u8`, `u16`, and `u32`")] ForbiddenNumericGenericType, + #[error("Invalid call data identifier, must be a number. E.g `call_data(0)`")] + InvalidCallDataIdentifier, } /// Represents a parsing error, or a parsing error in the making. diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index 3879f628eae..46d0ca5d206 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -46,6 +46,7 @@ use crate::ast::{ use crate::lexer::{lexer::from_spanned_token_result, Lexer}; use crate::parser::{force, ignore_then_commit, statement_recovery}; use crate::token::{Keyword, Token, TokenKind}; +use acvm::AcirField; use chumsky::prelude::*; use iter_extended::vecmap; @@ -645,19 +646,28 @@ where }) } +fn call_data() -> impl NoirParser { + keyword(Keyword::CallData).then(parenthesized(literal())).validate(|token, span, emit| { + match token { + (_, ExpressionKind::Literal(Literal::Integer(x, _))) => { + let id = x.to_u128() as u32; + Visibility::CallData(id) + } + _ => { + emit(ParserError::with_reason(ParserErrorReason::InvalidCallDataIdentifier, span)); + Visibility::CallData(0) + } + } + }) +} + fn optional_visibility() -> impl NoirParser { keyword(Keyword::Pub) - .or(keyword(Keyword::CallData)) - .or(keyword(Keyword::ReturnData)) + .map(|_| Visibility::Public) + .or(call_data()) + .or(keyword(Keyword::ReturnData).map(|_| Visibility::ReturnData)) .or_not() - .map(|opt| match opt { - Some(Token::Keyword(Keyword::Pub)) => Visibility::Public, - Some(Token::Keyword(Keyword::CallData)) | Some(Token::Keyword(Keyword::ReturnData)) => { - Visibility::DataBus - } - None => Visibility::Private, - _ => unreachable!("unexpected token found"), - }) + .map(|opt| opt.unwrap_or(Visibility::Private)) } pub fn expression() -> impl ExprParser { diff --git a/test_programs/execution_success/databus/src/main.nr b/test_programs/execution_success/databus/src/main.nr index 7e5c23d508d..1e4aa141eea 100644 --- a/test_programs/execution_success/databus/src/main.nr +++ b/test_programs/execution_success/databus/src/main.nr @@ -1,4 +1,4 @@ -fn main(mut x: u32, y: call_data u32, z: call_data [u32; 4]) -> return_data u32 { +fn main(mut x: u32, y: call_data(0) u32, z: call_data(0) [u32; 4]) -> return_data u32 { let a = z[x]; a + foo(y) } diff --git a/tooling/nargo_fmt/src/utils.rs b/tooling/nargo_fmt/src/utils.rs index 020f411ae2f..293351055e1 100644 --- a/tooling/nargo_fmt/src/utils.rs +++ b/tooling/nargo_fmt/src/utils.rs @@ -146,9 +146,10 @@ impl HasItem for Param { fn format(self, visitor: &FmtVisitor, shape: Shape) -> String { let pattern = visitor.slice(self.pattern.span()); let visibility = match self.visibility { - Visibility::Public => "pub", - Visibility::Private => "", - Visibility::DataBus => "call_data", + Visibility::Public => "pub".to_string(), + Visibility::Private => "".to_string(), + Visibility::CallData(x) => format!("call_data({x})"), + Visibility::ReturnData => "return_data".to_string(), }; if self.pattern.is_synthesized() || self.typ.is_synthesized() { diff --git a/tooling/nargo_fmt/src/visitor/item.rs b/tooling/nargo_fmt/src/visitor/item.rs index 5aaaf20ff47..0c9f61a7d40 100644 --- a/tooling/nargo_fmt/src/visitor/item.rs +++ b/tooling/nargo_fmt/src/visitor/item.rs @@ -120,8 +120,11 @@ impl super::FmtVisitor<'_> { let visibility = match func.def.return_visibility { Visibility::Public => "pub", - Visibility::DataBus => "return_data", + Visibility::ReturnData => "return_data", Visibility::Private => "", + Visibility::CallData(_) => { + unreachable!("call_data cannot be used for return value") + } }; result.push_str(&append_space_if_nonempty(visibility.into()));