diff --git a/acvm/Cargo.toml b/acvm/Cargo.toml index 6c99ab842..365a2009e 100644 --- a/acvm/Cargo.toml +++ b/acvm/Cargo.toml @@ -18,6 +18,7 @@ acir.workspace = true stdlib.workspace = true sha2 = "0.9.3" +sha3 = "0.10.7" crc32fast = "1.3.2" k256 = { version = "0.7.2", features = [ "ecdsa", diff --git a/acvm/src/lib.rs b/acvm/src/lib.rs index 7b75c1d0c..22cd0f833 100644 --- a/acvm/src/lib.rs +++ b/acvm/src/lib.rs @@ -16,7 +16,9 @@ use acir::{ native_types::{Expression, Witness}, BlackBoxFunc, }; -use pwg::{block::Blocks, directives::solve_directives}; +use pwg::{ + black_box_functions::solve_black_box_function, block::Blocks, directives::solve_directives, +}; use std::collections::BTreeMap; use thiserror::Error; @@ -122,7 +124,12 @@ pub trait PartialWitnessGenerator { unassigned_witness.0, ))) } else { - self.solve_black_box_function_call(initial_witness, bb_func) + let status = solve_black_box_function(initial_witness, bb_func); + if matches!(status, Err(OpcodeResolutionError::OpcodeNotSolvable(_))) { + self.solve_black_box_function_call(initial_witness, bb_func) + } else { + status + } } } Opcode::Directive(directive) => solve_directives(initial_witness, directive), diff --git a/acvm/src/pwg.rs b/acvm/src/pwg.rs index bf7bbdab5..a4d6c9aa5 100644 --- a/acvm/src/pwg.rs +++ b/acvm/src/pwg.rs @@ -14,6 +14,7 @@ pub mod arithmetic; // Directives pub mod directives; // black box functions +pub mod black_box_functions; pub mod block; pub mod hash; pub mod logic; diff --git a/acvm/src/pwg/black_box_functions.rs b/acvm/src/pwg/black_box_functions.rs new file mode 100644 index 000000000..db775f5b2 --- /dev/null +++ b/acvm/src/pwg/black_box_functions.rs @@ -0,0 +1,34 @@ +use std::collections::BTreeMap; + +use acir::{circuit::opcodes::BlackBoxFuncCall, native_types::Witness, BlackBoxFunc, FieldElement}; + +use crate::{OpcodeNotSolvable, OpcodeResolution, OpcodeResolutionError}; + +use super::hash; + +pub fn solve_black_box_function( + initial_witness: &mut BTreeMap, + func_call: &BlackBoxFuncCall, +) -> Result { + match func_call.name { + BlackBoxFunc::AES + | BlackBoxFunc::AND + | BlackBoxFunc::XOR + | BlackBoxFunc::RANGE + | BlackBoxFunc::SHA256 + | BlackBoxFunc::Blake2s + | BlackBoxFunc::ComputeMerkleRoot + | BlackBoxFunc::SchnorrVerify + | BlackBoxFunc::Pedersen + | BlackBoxFunc::HashToField128Security + | BlackBoxFunc::EcdsaSecp256k1 + | BlackBoxFunc::FixedBaseScalarMul => { + Err(OpcodeResolutionError::OpcodeNotSolvable(OpcodeNotSolvable::MissingAssignment(0))) + } + //self.solve_black_box_function_call(initial_witness, func_call), + BlackBoxFunc::Keccak256 => { + hash::keccak256(initial_witness, func_call)?; + Ok(OpcodeResolution::Solved) + } + } +} diff --git a/acvm/src/pwg/hash.rs b/acvm/src/pwg/hash.rs index d636da8e7..750d87e0a 100644 --- a/acvm/src/pwg/hash.rs +++ b/acvm/src/pwg/hash.rs @@ -1,6 +1,7 @@ use acir::{circuit::opcodes::BlackBoxFuncCall, native_types::Witness, FieldElement}; use blake2::{Blake2s, Digest}; use sha2::Sha256; +use sha3::Keccak256; use std::collections::BTreeMap; use crate::{OpcodeResolution, OpcodeResolutionError}; @@ -72,3 +73,43 @@ fn generic_hash_256( let result = hasher.finalize().as_slice().try_into().unwrap(); Ok(result) } + +pub fn keccak256( + initial_witness: &mut BTreeMap, + gadget_call: &BlackBoxFuncCall, +) -> Result { + generic_sha3::(initial_witness, gadget_call)?; + Ok(OpcodeResolution::Solved) +} + +fn generic_sha3( + initial_witness: &mut BTreeMap, + gadget_call: &BlackBoxFuncCall, +) -> Result<(), OpcodeResolutionError> { + let mut hasher = D::new(); + + // For each input in the vector of inputs, check if we have their witness assignments (Can do this outside of match, since they all have inputs) + for input_index in gadget_call.inputs.iter() { + let witness = &input_index.witness; + let num_bits = input_index.num_bits; + + let witness_assignment = initial_witness.get(witness); + let assignment = match witness_assignment { + None => panic!("cannot find witness assignment for {witness:?}"), + Some(assignment) => assignment, + }; + + let bytes = assignment.fetch_nearest_bytes(num_bits as usize); + hasher.update(bytes); + } + let result = hasher.finalize(); + assert_eq!(result.len(), 32); + for i in 0..32 { + insert_value( + &gadget_call.outputs[i], + FieldElement::from_be_bytes_reduce(&[result[i]]), + initial_witness, + )?; + } + Ok(()) +}