diff --git a/acir/src/circuit/opcodes.rs b/acir/src/circuit/opcodes.rs index c0a394bae..07c7b99f8 100644 --- a/acir/src/circuit/opcodes.rs +++ b/acir/src/circuit/opcodes.rs @@ -238,7 +238,7 @@ fn serialization_roundtrip() { let opcode_arith = Opcode::Arithmetic(Expression::default()); - let opcode_black_box_func = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::AES { + let aes_black_box_func = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::AES { inputs: vec![ FunctionInput { witness: Witness(1u32), num_bits: 12 }, FunctionInput { witness: Witness(24u32), num_bits: 32 }, @@ -246,10 +246,24 @@ fn serialization_roundtrip() { outputs: vec![Witness(123u32), Witness(245u32)], }); + let ecdsa_black_box_func = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::EcdsaSecp256k1 { + public_key_x: vec![ + FunctionInput { witness: Witness(10u32), num_bits: 8 }, + FunctionInput { witness: Witness(11u32), num_bits: 8 }, + ], + public_key_y: vec![ + FunctionInput { witness: Witness(12u32), num_bits: 8 }, + FunctionInput { witness: Witness(13u32), num_bits: 8 }, + ], + signature: vec![FunctionInput { witness: Witness(14u32), num_bits: 8 }], + hashed_message: vec![FunctionInput { witness: Witness(15u32), num_bits: 8 }], + output: Witness(300u32), + }); + let opcode_directive = Opcode::Directive(Directive::Invert { x: Witness(1234u32), result: Witness(56789u32) }); - let opcodes = vec![opcode_arith, opcode_black_box_func, opcode_directive]; + let opcodes = vec![opcode_arith, aes_black_box_func, opcode_directive, ecdsa_black_box_func]; for opcode in opcodes { let (op, got_op) = read_write(opcode); diff --git a/acir/src/circuit/opcodes/black_box_function_call.rs b/acir/src/circuit/opcodes/black_box_function_call.rs index d53486b07..119e759d1 100644 --- a/acir/src/circuit/opcodes/black_box_function_call.rs +++ b/acir/src/circuit/opcodes/black_box_function_call.rs @@ -5,8 +5,8 @@ use crate::serialization::{read_u16, read_u32, write_u16, write_u32}; use crate::BlackBoxFunc; use serde::{Deserialize, Serialize}; -// Note: Some functions will not use all of the witness -// So we need to supply how many bits of the witness is needed +/// Note: Some functions will not use all of the witness +/// So we need to supply how many bits of the witness is needed #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct FunctionInput { pub witness: Witness, @@ -19,6 +19,30 @@ impl FunctionInput { } } +/// A BlackBoxFuncCall can have multiple fields representing the inputs, +/// which can vary in length depending on the backend system. +/// There needs to be some length associated with each input field during serialization in order +/// to keep each function input length from being hardcoded. Singular function inputs +/// are not affected as the function signature determines whether we should fetch the first element +/// of a vector when reading in the function inputs or return a vector with a single element +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +pub struct FunctionInputIO { + length: usize, + inner: Vec, +} + +impl From> for FunctionInputIO { + fn from(func_input: Vec) -> Self { + FunctionInputIO { length: func_input.len(), inner: func_input } + } +} + +impl From for FunctionInputIO { + fn from(func_input: FunctionInput) -> Self { + FunctionInputIO { length: 1, inner: vec![func_input] } + } +} + #[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] pub enum BlackBoxFuncCall { #[allow(clippy::upper_case_acronyms)] @@ -93,12 +117,17 @@ fn write_input(input: &FunctionInput, mut writer: W) -> std::io::Resul Ok(()) } -fn write_inputs(inputs: &[FunctionInput], mut writer: W) -> std::io::Result<()> { +fn write_inputs(inputs: &[FunctionInputIO], mut writer: W) -> std::io::Result<()> { let num_inputs = inputs.len() as u32; write_u32(&mut writer, num_inputs)?; - for input in inputs { - write_input(input, &mut writer)?; + for input_io_info in inputs { + let input_length = input_io_info.length as u32; + write_u32(&mut writer, input_length)?; + let inner_inputs = &input_io_info.inner; + for input in inner_inputs { + write_input(input, &mut writer)?; + } } Ok(()) @@ -121,14 +150,22 @@ fn read_input(mut reader: R) -> std::io::Result { Ok(FunctionInput { witness: Witness::new(witness_index), num_bits }) } -fn read_inputs(mut reader: R) -> std::io::Result> { +fn read_inputs(mut reader: R) -> std::io::Result> { let num_inputs = read_u32(&mut reader)?; let mut inputs = Vec::new(); inputs.try_reserve_exact(num_inputs as usize).map_err(|_| std::io::ErrorKind::InvalidData)?; for _ in 0..num_inputs { - inputs.push(read_input(&mut reader)?); + let input_length = read_u32(&mut reader)?; + let mut inner_inputs = Vec::new(); + inner_inputs + .try_reserve_exact(input_length as usize) + .map_err(|_| std::io::ErrorKind::InvalidData)?; + for _ in 0..input_length { + inner_inputs.push(read_input(&mut reader)?); + } + inputs.push(FunctionInputIO { length: input_length as usize, inner: inner_inputs }); } Ok(inputs) @@ -223,25 +260,23 @@ impl BlackBoxFuncCall { self.get_black_box_func().name() } - pub fn get_inputs_vec(&self) -> Vec { + pub fn get_inputs_io_vec(&self) -> Vec { match self { BlackBoxFuncCall::AES { inputs, .. } | BlackBoxFuncCall::SHA256 { inputs, .. } | BlackBoxFuncCall::Blake2s { inputs, .. } | BlackBoxFuncCall::Keccak256 { inputs, .. } | BlackBoxFuncCall::Pedersen { inputs, .. } - | BlackBoxFuncCall::HashToField128Security { inputs, .. } => inputs.to_vec(), + | BlackBoxFuncCall::HashToField128Security { inputs, .. } => { + vec![inputs.clone().into()] + } BlackBoxFuncCall::AND { lhs, rhs, .. } | BlackBoxFuncCall::XOR { lhs, rhs, .. } => { - vec![*lhs, *rhs] + vec![(*lhs).into(), (*rhs).into()] } BlackBoxFuncCall::FixedBaseScalarMul { input, .. } - | BlackBoxFuncCall::RANGE { input } => vec![*input], + | BlackBoxFuncCall::RANGE { input } => vec![(*input).into()], BlackBoxFuncCall::ComputeMerkleRoot { leaf, index, hash_path, .. } => { - let mut inputs = Vec::with_capacity(2 + hash_path.len()); - inputs.push(*leaf); - inputs.push(*index); - inputs.extend(hash_path.iter().copied()); - inputs + vec![(*leaf).into(), (*index).into(), hash_path.clone().into()] } BlackBoxFuncCall::SchnorrVerify { public_key_x, @@ -249,33 +284,34 @@ impl BlackBoxFuncCall { signature, message, .. - } => { - let mut inputs = Vec::with_capacity(2 + signature.len() + message.len()); - inputs.push(*public_key_x); - inputs.push(*public_key_y); - inputs.extend(signature.iter().copied()); - inputs.extend(message.iter().copied()); - inputs - } + } => vec![ + (*public_key_x).into(), + (*public_key_y).into(), + signature.clone().into(), + message.clone().into(), + ], BlackBoxFuncCall::EcdsaSecp256k1 { public_key_x, public_key_y, signature, hashed_message: message, .. - } => { - let mut inputs = Vec::with_capacity( - public_key_x.len() + public_key_y.len() + signature.len() + message.len(), - ); - inputs.extend(public_key_x.iter().copied()); - inputs.extend(public_key_y.iter().copied()); - inputs.extend(signature.iter().copied()); - inputs.extend(message.iter().copied()); - inputs - } + } => vec![ + public_key_x.clone().into(), + public_key_y.clone().into(), + signature.clone().into(), + message.clone().into(), + ], } } + /// A flattened vector of all the function inputs + /// Used for displaying black box funcs and gadget simplification + pub fn get_inputs_vec(&self) -> Vec { + let inputs_io = self.get_inputs_io_vec(); + inputs_io.iter().flat_map(|io_obj| io_obj.inner.clone()).collect::>() + } + pub fn get_outputs_vec(&self) -> Vec { match self { BlackBoxFuncCall::AES { outputs, .. } @@ -297,7 +333,7 @@ impl BlackBoxFuncCall { pub fn write(&self, mut writer: W) -> std::io::Result<()> { write_u16(&mut writer, self.get_black_box_func().to_u16())?; - write_inputs(&self.get_inputs_vec(), &mut writer)?; + write_inputs(&self.get_inputs_io_vec(), &mut writer)?; write_outputs(&self.get_outputs_vec(), &mut writer)?; Ok(()) @@ -311,13 +347,15 @@ impl BlackBoxFuncCall { let outputs = read_outputs(&mut reader)?; match name { - BlackBoxFunc::AES => Ok(BlackBoxFuncCall::AES { inputs, outputs }), + BlackBoxFunc::AES => { + Ok(BlackBoxFuncCall::AES { inputs: inputs[0].inner.clone(), outputs }) + } BlackBoxFunc::AND => { if inputs.len() != 2 || outputs.len() != 1 { Err(std::io::ErrorKind::InvalidData.into()) } else { - let lhs = inputs[0]; - let rhs = inputs[1]; + let lhs = inputs[0].inner[0]; + let rhs = inputs[1].inner[0]; let output = outputs[0]; Ok(BlackBoxFuncCall::AND { lhs, rhs, output }) } @@ -326,8 +364,8 @@ impl BlackBoxFuncCall { if inputs.len() != 2 || outputs.len() != 1 { Err(std::io::ErrorKind::InvalidData.into()) } else { - let lhs = inputs[0]; - let rhs = inputs[1]; + let lhs = inputs[0].inner[0]; + let rhs = inputs[1].inner[0]; let output = outputs[0]; Ok(BlackBoxFuncCall::XOR { lhs, rhs, output }) } @@ -336,53 +374,62 @@ impl BlackBoxFuncCall { if inputs.len() != 1 { Err(std::io::ErrorKind::InvalidData.into()) } else { - Ok(BlackBoxFuncCall::RANGE { input: inputs[0] }) + Ok(BlackBoxFuncCall::RANGE { input: inputs[0].inner[0] }) } } - BlackBoxFunc::SHA256 => Ok(BlackBoxFuncCall::SHA256 { inputs, outputs }), - BlackBoxFunc::Blake2s => Ok(BlackBoxFuncCall::Blake2s { inputs, outputs }), + BlackBoxFunc::SHA256 => { + Ok(BlackBoxFuncCall::SHA256 { inputs: inputs[0].inner.clone(), outputs }) + } + BlackBoxFunc::Blake2s => { + Ok(BlackBoxFuncCall::Blake2s { inputs: inputs[0].inner.clone(), outputs }) + } BlackBoxFunc::ComputeMerkleRoot => { if inputs.len() < 2 || outputs.len() != 1 { Err(std::io::ErrorKind::InvalidData.into()) } else { Ok(BlackBoxFuncCall::ComputeMerkleRoot { - leaf: inputs[0], - index: inputs[1], - hash_path: inputs[2..].to_vec(), + leaf: inputs[0].inner[0], + index: inputs[1].inner[1], + hash_path: inputs[2].inner.clone(), output: outputs[0], }) } } BlackBoxFunc::SchnorrVerify => { - if inputs.len() < 66 || outputs.len() != 1 { + if inputs.len() < 4 || outputs.len() != 1 { Err(std::io::ErrorKind::InvalidData.into()) } else { Ok(BlackBoxFuncCall::SchnorrVerify { - public_key_x: inputs[0], - public_key_y: inputs[1], - signature: inputs[2..66].to_vec(), - message: inputs[66..].to_vec(), + public_key_x: inputs[0].inner[0], + public_key_y: inputs[1].inner[0], + signature: inputs[2].inner.clone(), + message: inputs[3].inner.clone(), output: outputs[0], }) } } - BlackBoxFunc::Pedersen => Ok(BlackBoxFuncCall::Pedersen { inputs, outputs }), + BlackBoxFunc::Pedersen => { + Ok(BlackBoxFuncCall::Pedersen { inputs: inputs[0].inner.clone(), outputs }) + } BlackBoxFunc::HashToField128Security => { if outputs.len() != 1 { Err(std::io::ErrorKind::InvalidData.into()) } else { - Ok(BlackBoxFuncCall::HashToField128Security { inputs, output: outputs[0] }) + Ok(BlackBoxFuncCall::HashToField128Security { + inputs: inputs[0].inner.clone(), + output: outputs[0], + }) } } BlackBoxFunc::EcdsaSecp256k1 => { - if inputs.len() < 128 || outputs.len() != 1 { + if inputs.len() < 4 || outputs.len() != 1 { Err(std::io::ErrorKind::InvalidData.into()) } else { Ok(BlackBoxFuncCall::EcdsaSecp256k1 { - public_key_x: inputs[0..32].to_vec(), - public_key_y: inputs[32..64].to_vec(), - signature: inputs[64..128].to_vec(), - hashed_message: inputs[128..].to_vec(), + public_key_x: inputs[0].inner.clone(), + public_key_y: inputs[1].inner.clone(), + signature: inputs[2].inner.clone(), + hashed_message: inputs[3].inner.clone(), output: outputs[0], }) } @@ -391,10 +438,12 @@ impl BlackBoxFuncCall { if inputs.len() != 1 { Err(std::io::ErrorKind::InvalidData.into()) } else { - Ok(BlackBoxFuncCall::FixedBaseScalarMul { input: inputs[0], outputs }) + Ok(BlackBoxFuncCall::FixedBaseScalarMul { input: inputs[0].inner[0], outputs }) } } - BlackBoxFunc::Keccak256 => Ok(BlackBoxFuncCall::Keccak256 { inputs, outputs }), + BlackBoxFunc::Keccak256 => { + Ok(BlackBoxFuncCall::Keccak256 { inputs: inputs[0].inner.clone(), outputs }) + } } } }