diff --git a/halo2-base/src/poseidon/hasher/mod.rs b/halo2-base/src/poseidon/hasher/mod.rs index 2608cc36..50821348 100644 --- a/halo2-base/src/poseidon/hasher/mod.rs +++ b/halo2-base/src/poseidon/hasher/mod.rs @@ -86,6 +86,22 @@ impl PoseidonCompactInput { } } +/// A compact chunk input for Poseidon hasher. The end of a logical input can only be at the boundary of a chunk. +#[derive(Clone, Debug)] +pub struct PoseidonCompactChunkInput { + // Inputs of a chunk. All witnesses will be absorbed. + inputs: Vec<[AssignedValue; RATE]>, + // is_final = 1 triggers squeeze. + is_final: SafeBool, +} + +impl PoseidonCompactChunkInput { + /// Create a new PoseidonCompactChunkInput. + pub fn new(inputs: Vec<[AssignedValue; RATE]>, is_final: SafeBool) -> Self { + Self { inputs, is_final } + } +} + /// 1 logical row of compact output for Poseidon hasher. #[derive(Copy, Clone, Debug, Getters)] pub struct PoseidonCompactOutput { @@ -232,6 +248,36 @@ impl PoseidonHasher, + range: &impl RangeInstructions, + chunk_inputs: &[PoseidonCompactChunkInput], + ) -> Vec> + where + F: BigPrimeField, + { + let zero_witness = ctx.load_zero(); + let mut outputs = Vec::with_capacity(chunk_inputs.len()); + let mut state = self.init_state().clone(); + for chunk_input in chunk_inputs { + let is_final = chunk_input.is_final; + for absorb in &chunk_input.inputs { + state.permutation(ctx, range.gate(), absorb, None, &self.spec); + } + // Because the length of each absorb is always RATE, an extra permutation is needed for squeeze. + let mut output_state = state.clone(); + output_state.permutation(ctx, range.gate(), &[], None, &self.spec); + let hash = + range.gate().select(ctx, output_state.s[1], zero_witness, *is_final.as_ref()); + outputs.push(PoseidonCompactOutput { hash, is_final }); + // Reset state to init_state if this is the end of a logical input. + state.select(ctx, range.gate(), is_final, self.init_state()); + } + outputs + } } /// Poseidon sponge. This is stateful. 
diff --git a/halo2-base/src/poseidon/hasher/tests/hasher.rs b/halo2-base/src/poseidon/hasher/tests/hasher.rs index 2023c4ec..68207d83 100644 --- a/halo2-base/src/poseidon/hasher/tests/hasher.rs +++ b/halo2-base/src/poseidon/hasher/tests/hasher.rs @@ -1,12 +1,16 @@ use crate::{ gates::{range::RangeInstructions, RangeChip}, halo2_proofs::halo2curves::bn256::Fr, - poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonCompactInput, PoseidonHasher}, + poseidon::hasher::{ + spec::OptimizedPoseidonSpec, PoseidonCompactChunkInput, PoseidonCompactInput, + PoseidonHasher, + }, safe_types::SafeTypeChip, utils::{testing::base_test, ScalarField}, Context, }; use halo2_proofs_axiom::arithmetic::Field; +use itertools::Itertools; use pse_poseidon::Poseidon; use rand::Rng; @@ -111,6 +115,61 @@ fn hasher_compact_inputs_compatiblity_verification< } } +// Check that the results from the hasher and the native sponge are the same for hash_compact_chunk_inputs. +fn hasher_compact_chunk_inputs_compatiblity_verification< + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +>( + payloads: Vec<(Payload, bool)>, + ctx: &mut Context, + range: &RangeChip, +) { + // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. + let spec = OptimizedPoseidonSpec::::new::(); + let mut hasher = PoseidonHasher::::new(spec); + hasher.initialize_consts(ctx, range.gate()); + + let mut native_results = Vec::with_capacity(payloads.len()); + let mut chunk_inputs = Vec::>::new(); + let true_witness = SafeTypeChip::unsafe_to_bool(ctx.load_constant(Fr::ONE)); + let false_witness = SafeTypeChip::unsafe_to_bool(ctx.load_zero()); + + // Construct native Poseidon sponge. 
+ let mut native_sponge = Poseidon::::new(R_F, R_P); + for (payload, is_final) in payloads { + assert!(payload.values.len() == payload.len); + assert!(payload.values.len() % RATE == 0); + let inputs = ctx.assign_witnesses(payload.values.clone()); + + let is_final_witness = if is_final { true_witness } else { false_witness }; + chunk_inputs.push(PoseidonCompactChunkInput { + inputs: inputs.chunks(RATE).map(|c| c.try_into().unwrap()).collect_vec(), + is_final: is_final_witness, + }); + native_sponge.update(&payload.values); + if is_final { + let native_result = native_sponge.squeeze(); + native_results.push(native_result); + native_sponge = Poseidon::::new(R_F, R_P); + } + } + let compact_outputs = hasher.hash_compact_chunk_inputs(ctx, range, &chunk_inputs); + assert_eq!(chunk_inputs.len(), compact_outputs.len()); + let mut output_offset = 0; + for (compact_output, chunk_input) in compact_outputs.iter().zip(chunk_inputs) { + // into() doesn't work if ! is in the beginning in the bool expression... 
+ let is_final_input = chunk_input.is_final.as_ref().value(); + let is_final_output = compact_output.is_final().as_ref().value(); + assert_eq!(is_final_input, is_final_output); + if is_final_output == &Fr::ONE { + assert_eq!(native_results[output_offset], *compact_output.hash().value()); + output_offset += 1; + } + } +} + fn random_payload(max_len: usize, len: usize, max_value: usize) -> Payload { assert!(len <= max_len); let mut rng = rand::thread_rng(); @@ -235,3 +294,65 @@ fn test_poseidon_hasher_compact_inputs_with_prover() { }); } } + +#[test] +fn test_poseidon_hasher_compact_chunk_inputs() { + { + const T: usize = 3; + const RATE: usize = 2; + let payloads = vec![ + (random_payload(RATE * 5, RATE * 5, usize::MAX), true), + (random_payload(RATE, RATE, usize::MAX), false), + (random_payload(RATE * 2, RATE * 2, usize::MAX), true), + (random_payload(RATE * 3, RATE * 3, usize::MAX), true), + ]; + base_test().k(12).run(|ctx, range| { + hasher_compact_chunk_inputs_compatiblity_verification::( + payloads, ctx, range, + ); + }); + } + { + const T: usize = 3; + const RATE: usize = 2; + let payloads = vec![ + (random_payload(0, 0, usize::MAX), true), + (random_payload(0, 0, usize::MAX), false), + (random_payload(0, 0, usize::MAX), false), + ]; + base_test().k(12).run(|ctx, range| { + hasher_compact_chunk_inputs_compatiblity_verification::( + payloads, ctx, range, + ); + }); + } +} + +#[test] +fn test_poseidon_hasher_compact_chunk_inputs_with_prover() { + { + const T: usize = 3; + const RATE: usize = 2; + let params = [ + (RATE, false), + (RATE * 2, false), + (RATE * 5, false), + (RATE * 2, true), + (RATE * 5, true), + ]; + let init_payloads = params + .iter() + .map(|(len, is_final)| (random_payload(*len, *len, usize::MAX), *is_final)) + .collect::>(); + let logic_payloads = params + .iter() + .map(|(len, is_final)| (random_payload(*len, *len, usize::MAX), *is_final)) + .collect::>(); + base_test().k(12).bench_builder(init_payloads, logic_payloads, |pool, range, input| 
{ + let ctx = pool.main(); + hasher_compact_chunk_inputs_compatiblity_verification::( + input, ctx, range, + ); + }); + } +} diff --git a/halo2-base/src/safe_types/bytes.rs b/halo2-base/src/safe_types/bytes.rs index c0372624..3e7fffea 100644 --- a/halo2-base/src/safe_types/bytes.rs +++ b/halo2-base/src/safe_types/bytes.rs @@ -52,6 +52,12 @@ impl VarLenBytes { padded.into_iter().map(|b| SafeByte(b)).collect::>().try_into().unwrap(), ) } + + /// Return a copy of the byte array with 0 padding ensured. + pub fn ensure_0_padding(&self, ctx: &mut Context, gate: &impl GateInstructions) -> Self { + let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len); + Self::new(bytes.try_into().unwrap(), self.len) + } } /// Represents a variable length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time. @@ -93,7 +99,13 @@ impl VarLenBytesVec { gate: &impl GateInstructions, ) -> FixLenBytesVec { let padded = left_pad_var_array_to_fixed(ctx, gate, &self.bytes, self.len, self.max_len()); - padded.into_iter().map(|b| SafeByte(b)).collect() + FixLenBytesVec::new(padded.into_iter().map(|b| SafeByte(b)).collect_vec(), self.max_len()) + } + + /// Return a copy of the byte array with 0 padding ensured. + pub fn ensure_0_padding(&self, ctx: &mut Context, gate: &impl GateInstructions) -> Self { + let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len); + Self::new(bytes, self.len, self.max_len()) } } @@ -117,6 +129,27 @@ impl FixLenBytes { } } +/// Represents a fixed length byte array in circuit. Not encouraged to use because its length cannot be verified at compile time. +#[derive(Debug, Clone, Getters)] +pub struct FixLenBytesVec { + /// The byte array + #[getset(get = "pub")] + bytes: Vec>, +} + +impl FixLenBytesVec { + // FixLenBytesVec can only be created within the `safe_types` module. 
+ pub(super) fn new(bytes: Vec>, len: usize) -> Self { + assert_eq!(bytes.len(), len, "bytes length doesn't match"); + Self { bytes } + } + + /// Returns the length of the byte array. + pub fn len(&self) -> usize { + self.bytes.len() + } +} + impl From> for FixLenBytes::VALUE_LENGTH }> { @@ -138,7 +171,7 @@ impl /// Represents a fixed length byte array in circuit as a vector, where length must be fixed. /// Not encouraged to use because `LEN` cannot be verified at compile time. -pub type FixLenBytesVec = Vec>; +// pub type FixLenBytesVec = Vec>; /// Takes a fixed length array `arr` and returns a length `out_len` array equal to /// `[[0; out_len - len], arr[..len]].concat()`, i.e., we take `arr[..len]` and @@ -172,3 +205,24 @@ pub fn left_pad_var_array_to_fixed( } padded } + +fn ensure_0_padding( + ctx: &mut Context, + gate: &impl GateInstructions, + bytes: &[SafeByte], + len: AssignedValue, +) -> Vec> { + let max_len = bytes.len(); + // Generate a mask array where a[i] = i < len for i = 0..max_len. + let idx = gate.dec(ctx, len); + let len_indicator = gate.idx_to_indicator(ctx, idx, max_len); + // inputs_mask[i] = sum(len_indicator[i..]) + let mut mask = gate.partial_sums(ctx, len_indicator.clone().into_iter().rev()).collect_vec(); + mask.reverse(); + + bytes + .iter() + .zip(mask.iter()) + .map(|(byte, mask)| SafeByte(gate.mul(ctx, byte.0, *mask))) + .collect_vec() +} diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs index c34b2a51..32171c53 100644 --- a/halo2-base/src/safe_types/mod.rs +++ b/halo2-base/src/safe_types/mod.rs @@ -228,6 +228,18 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { FixLenBytes::::new(inputs.map(|input| Self::unsafe_to_byte(input))) } + /// Unsafe method that directly converts `inputs` to [`FixLenBytesVec`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeByte`]. 
+ pub fn unsafe_to_fix_len_bytes_vec( + inputs: RawAssignedValues, + len: usize, + ) -> FixLenBytesVec { + FixLenBytesVec::::new( + inputs.into_iter().map(|input| Self::unsafe_to_byte(input)).collect_vec(), + len, + ) + } + /// Converts a slice of AssignedValue(treated as little-endian) to VarLenBytes. /// /// * ctx: Circuit [Context] to assign witnesses to. @@ -249,7 +261,7 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { /// * ctx: Circuit [Context] to assign witnesses to. /// * inputs: Vector representing the byte array, right padded to `max_len`. See [VarLenBytesVec] for details about padding. /// * len: [AssignedValue] witness representing the variable length of the byte array. Constrained to be `<= max_len`. - /// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain. + /// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain. We enforce this to be provided explicitly to make sure the length of `inputs` is deterministic. pub fn raw_to_var_len_bytes_vec( &self, ctx: &mut Context, @@ -278,6 +290,23 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { FixLenBytes::::new(inputs.map(|input| self.assert_byte(ctx, input))) } + /// Converts a slice of AssignedValue(treated as little-endian) to FixLenBytesVec. + /// + /// * ctx: Circuit [Context] to assign witnesses to. + /// * inputs: Slice representing the byte array. + /// * len: length of the byte array. We enforce this to be provided explicitly to make sure the length of `inputs` is deterministic. 
+ pub fn raw_to_fix_len_bytes_vec( + &self, + ctx: &mut Context, + inputs: RawAssignedValues, + len: usize, + ) -> FixLenBytesVec { + FixLenBytesVec::::new( + inputs.into_iter().map(|input| self.assert_byte(ctx, input)).collect_vec(), + len, + ) + } + fn add_bytes_constraints( &self, ctx: &mut Context, diff --git a/halo2-base/src/safe_types/tests/bytes.rs b/halo2-base/src/safe_types/tests/bytes.rs index 966dffb4..9c24444f 100644 --- a/halo2-base/src/safe_types/tests/bytes.rs +++ b/halo2-base/src/safe_types/tests/bytes.rs @@ -55,7 +55,7 @@ fn left_pad_var_len_bytes(mut bytes: Vec, max_len: usize) -> Vec { let len = ctx.load_witness(Fr::from(len as u64)); let bytes = safe.raw_to_var_len_bytes_vec(ctx, bytes, len, max_len); let padded = bytes.left_pad_to_fixed(ctx, range.gate()); - padded.iter().map(|b| b.as_ref().value().get_lower_64() as u8).collect() + padded.bytes().iter().map(|b| b.as_ref().value().get_lower_64() as u8).collect() }) } @@ -132,7 +132,7 @@ fn neg_var_len_bytes_vec_len_less_than_max_len() { // Circuit Satisfied for valid inputs #[test] -fn pos_fix_len_bytes_vec() { +fn pos_fix_len_bytes() { base_test().k(10).lookup_bits(8).run(|ctx, range| { let safe = SafeTypeChip::new(range); let fake_bytes = ctx.assign_witnesses( @@ -142,6 +142,31 @@ fn pos_fix_len_bytes_vec() { }); } +// Assert inputs.len() == len +#[test] +#[should_panic] +fn neg_fix_len_bytes_vec() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let fake_bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_fix_len_bytes_vec(ctx, fake_bytes, 5); + }); +} + +// Circuit Satisfied for valid inputs +#[test] +fn pos_fix_len_bytes_vec() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let fake_bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + 
safe.raw_to_fix_len_bytes_vec(ctx, fake_bytes, 4); + }); +} + // =========== Prover =========== #[test] fn pos_prover_satisfied() { diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs index 6d4169e4..63a8945a 100644 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs +++ b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs @@ -360,7 +360,7 @@ impl KeccakCoprocessorLeafCircuit { let mut circuit_final_outputs = Vec::with_capacity(loaded_keccak_fs.len()); for (compact_output, loaded_keccak_f) in - lookup_key_per_keccak_f.iter().zip(loaded_keccak_fs) + lookup_key_per_keccak_f.iter().zip_eq(loaded_keccak_fs) { let is_final = AssignedValue::from(loaded_keccak_f.is_final); let key = gate.select(ctx, *compact_output.hash(), dummy_key_witness, is_final); @@ -413,7 +413,7 @@ impl KeccakCoprocessorLeafCircuit { } } -fn create_hasher() -> PoseidonHasher { +pub(crate) fn create_hasher() -> PoseidonHasher { // Construct in-circuit Poseidon hasher. let spec = OptimizedPoseidonSpec::::new::< POSEIDON_R_F, @@ -491,6 +491,7 @@ pub fn encode_inputs_from_keccak_fs( last_is_final = is_final.into(); } + // TODO: use hash_compact_chunk_input instead. 
let compact_outputs = initialized_hasher.hash_compact_input(ctx, gate, &compact_inputs); compact_outputs diff --git a/hashes/zkevm/src/keccak/coprocessor/encode.rs b/hashes/zkevm/src/keccak/coprocessor/encode.rs index 4922b817..cfba6de6 100644 --- a/hashes/zkevm/src/keccak/coprocessor/encode.rs +++ b/hashes/zkevm/src/keccak/coprocessor/encode.rs @@ -1,6 +1,18 @@ +use halo2_base::{ + gates::{GateInstructions, RangeInstructions}, + poseidon::hasher::{PoseidonCompactChunkInput, PoseidonHasher}, + safe_types::{FixLenBytesVec, SafeByte, SafeTypeChip, VarLenBytesVec}, + utils::bit_length, + AssignedValue, Context, + QuantumCell::Constant, +}; use itertools::Itertools; +use num_bigint::BigUint; -use crate::{keccak::vanilla::param::*, util::eth_types::Field}; +use crate::{ + keccak::vanilla::{keccak_packed_multi::get_num_keccak_f, param::*}, + util::eth_types::Field, +}; use super::param::*; @@ -31,7 +43,7 @@ pub fn encode_native_input(bytes: &[u8]) -> F { } // 1. Split Keccak words into keccak_fs(each keccak_f has NUM_WORDS_TO_ABSORB). // 2. Append an extra word into the beginning of each keccak_f. In the first keccak_f, this word is the byte length of the input. Otherwise 0. - let words_per_chunk = words + let words_per_keccak_f = words .chunks(NUM_WORDS_TO_ABSORB) .enumerate() .map(|(i, chunk)| { @@ -42,7 +54,7 @@ pub fn encode_native_input(bytes: &[u8]) -> F { }) .collect_vec(); // Compress every num_word_per_witness words into a witness. - let witnesses_per_chunk = words_per_chunk + let witnesses_per_keccak_f = words_per_keccak_f .iter() .map(|chunk| { chunk @@ -58,7 +70,7 @@ pub fn encode_native_input(bytes: &[u8]) -> F { // Absorb witnesses keccak_f by keccak_f. 
let mut native_poseidon_sponge = pse_poseidon::Poseidon::::new(POSEIDON_R_F, POSEIDON_R_P); - for witnesses in witnesses_per_chunk { + for witnesses in witnesses_per_keccak_f { for absorbing in witnesses.chunks(POSEIDON_RATE) { // To avoid absorbing witnesses crossing keccak_fs together, pad 0s to make sure absorb.len() == RATE. let mut padded_absorb = [F::ZERO; POSEIDON_RATE]; @@ -69,7 +81,60 @@ pub fn encode_native_input(bytes: &[u8]) -> F { native_poseidon_sponge.squeeze() } -// TODO: Add a function to encode a VarLenBytes into a lookup key. The function should be used by App Circuits. +/// Encode a VarLenBytesVec into its corresponding lookup key. +pub fn encode_var_len_bytes_vec( + ctx: &mut Context, + range_chip: &impl RangeInstructions, + initialized_hasher: &PoseidonHasher, + bytes: &VarLenBytesVec, +) -> AssignedValue { + let max_len = bytes.max_len(); + let max_num_keccak_f = get_num_keccak_f(max_len); + // num_keccak_f = len / NUM_BYTES_TO_ABSORB + 1 + let num_bits = bit_length(max_len as u64); + let (num_keccak_f, _) = + range_chip.div_mod(ctx, *bytes.len(), BigUint::from(NUM_BYTES_TO_ABSORB), num_bits); + let f_indicator = range_chip.gate().idx_to_indicator(ctx, num_keccak_f, max_num_keccak_f); + + let bytes = bytes.ensure_0_padding(ctx, range_chip.gate()); + let chunk_input_per_f = format_input(ctx, range_chip.gate(), bytes.bytes(), *bytes.len()); + + let chunk_inputs = chunk_input_per_f + .into_iter() + .zip(&f_indicator) + .map(|(chunk_input, is_final)| { + let is_final = SafeTypeChip::unsafe_to_bool(*is_final); + PoseidonCompactChunkInput::new(chunk_input, is_final) + }) + .collect_vec(); + + let compact_outputs = + initialized_hasher.hash_compact_chunk_inputs(ctx, range_chip, &chunk_inputs); + range_chip.gate().select_by_indicator( + ctx, + compact_outputs.into_iter().map(|o| *o.hash()), + f_indicator, + ) +} + +/// Encode a FixLenBytesVec into its corresponding lookup key. 
+pub fn encode_fix_len_bytes_vec( + ctx: &mut Context, + gate_chip: &impl GateInstructions, + initialized_hasher: &PoseidonHasher, + bytes: &FixLenBytesVec, +) -> AssignedValue { + // Constant witnesses + let len_witness = ctx.load_constant(F::from(bytes.len() as u64)); + + let chunk_input_per_f = format_input(ctx, gate_chip, bytes.bytes(), len_witness); + let flatten_inputs = chunk_input_per_f + .into_iter() + .flat_map(|chunk_input| chunk_input.into_iter().flatten()) + .collect_vec(); + + initialized_hasher.hash_fix_len_array(ctx, gate_chip, &flatten_inputs) +} // For reference, when F is bn254::Fr: // num_word_per_witness = 3 @@ -114,3 +179,78 @@ pub(crate) fn get_words_to_witness_multipliers() -> Vec { } multipliers } + +pub(crate) fn get_bytes_to_words_multipliers() -> Vec { + let mut multiplier_f = F::ONE; + let mut multipliers = Vec::with_capacity(NUM_BYTES_PER_WORD); + multipliers.push(multiplier_f); + let base_f = F::from_u128(1 << NUM_BITS_PER_BYTE); + for _ in 1..NUM_BYTES_PER_WORD { + multiplier_f *= base_f; + multipliers.push(multiplier_f); + } + multipliers +} + +fn format_input( + ctx: &mut Context, + gate: &impl GateInstructions, + bytes: &[SafeByte], + len: AssignedValue, +) -> Vec; POSEIDON_RATE]>> { + // Constant witnesses + let zero_const = ctx.load_zero(); + let bytes_to_words_multipliers_val = + get_bytes_to_words_multipliers::().into_iter().map(|m| Constant(m)).collect_vec(); + let words_to_witness_multipliers_val = + get_words_to_witness_multipliers::().into_iter().map(|m| Constant(m)).collect_vec(); + + let mut bytes_witnesses = bytes.to_vec(); + // Append a zero to the end because An extra keccak_f is performed if len % NUM_BYTES_TO_ABSORB == 0. 
+ bytes_witnesses.push(SafeTypeChip::unsafe_to_byte(zero_const)); + let words = bytes_witnesses + .chunks(NUM_BYTES_PER_WORD) + .map(|c| { + let len = c.len(); + let multipliers = bytes_to_words_multipliers_val[..len].to_vec(); + gate.inner_product(ctx, c.iter().map(|sb| *sb.as_ref()), multipliers) + }) + .collect_vec(); + + let words_per_f = words + .chunks(NUM_WORDS_TO_ABSORB) + .enumerate() + .map(|(i, words_per_f)| { + let mut buffer = [zero_const; NUM_WORDS_TO_ABSORB + 1]; + buffer[0] = if i == 0 { len } else { zero_const }; + buffer[1..words_per_f.len() + 1].copy_from_slice(words_per_f); + buffer + }) + .collect_vec(); + + let witnesses_per_f = words_per_f + .iter() + .map(|words| { + words + .chunks(num_word_per_witness::()) + .map(|c| { + gate.inner_product(ctx, c.to_vec(), words_to_witness_multipliers_val.clone()) + }) + .collect_vec() + }) + .collect_vec(); + + witnesses_per_f + .iter() + .map(|words| { + words + .chunks(POSEIDON_RATE) + .map(|c| { + let mut buffer = [zero_const; POSEIDON_RATE]; + buffer[..c.len()].copy_from_slice(c); + buffer + }) + .collect_vec() + }) + .collect_vec() +} diff --git a/hashes/zkevm/src/keccak/coprocessor/tests/encode.rs b/hashes/zkevm/src/keccak/coprocessor/tests/encode.rs new file mode 100644 index 00000000..761a4e9a --- /dev/null +++ b/hashes/zkevm/src/keccak/coprocessor/tests/encode.rs @@ -0,0 +1,124 @@ +use ethers_core::k256::elliptic_curve::Field; +use halo2_base::{ + gates::{GateInstructions, RangeChip, RangeInstructions}, + halo2_proofs::halo2curves::bn256::Fr, + safe_types::SafeTypeChip, + utils::testing::base_test, + Context, +}; +use itertools::Itertools; + +use crate::keccak::coprocessor::{ + circuit::leaf::create_hasher, + encode::{encode_fix_len_bytes_vec, encode_native_input, encode_var_len_bytes_vec}, +}; + +fn build_and_verify_encode_var_len_bytes_vec( + inputs: Vec<(Vec, usize)>, + ctx: &mut Context, + range_chip: &RangeChip, +) { + let mut hasher = create_hasher(); + hasher.initialize_consts(ctx, 
range_chip.gate()); + + for (input, max_len) in inputs { + let expected = encode_native_input::(&input); + let len = ctx.load_witness(Fr::from(input.len() as u64)); + let mut witnesses_val = vec![Fr::ZERO; max_len]; + witnesses_val[..input.len()] + .copy_from_slice(&input.iter().map(|b| Fr::from(*b as u64)).collect_vec()); + let input_witnesses = ctx.assign_witnesses(witnesses_val); + let var_len_bytes_vec = + SafeTypeChip::unsafe_to_var_len_bytes_vec(input_witnesses, len, max_len); + let encoded = encode_var_len_bytes_vec(ctx, range_chip, &hasher, &var_len_bytes_vec); + assert_eq!(encoded.value(), &expected); + } +} + +fn build_and_verify_encode_fix_len_bytes_vec( + inputs: Vec>, + ctx: &mut Context, + gate_chip: &impl GateInstructions, +) { + let mut hasher = create_hasher(); + hasher.initialize_consts(ctx, gate_chip); + + for input in inputs { + let expected = encode_native_input::(&input); + let len = input.len(); + let witnesses_val = input.into_iter().map(|b| Fr::from(b as u64)).collect_vec(); + let input_witnesses = ctx.assign_witnesses(witnesses_val); + let fix_len_bytes_vec = SafeTypeChip::unsafe_to_fix_len_bytes_vec(input_witnesses, len); + let encoded = encode_fix_len_bytes_vec(ctx, gate_chip, &hasher, &fix_len_bytes_vec); + assert_eq!(encoded.value(), &expected); + } +} + +#[test] +fn mock_encode_var_len_bytes_vec() { + let inputs = vec![ + (vec![], 1), + (vec![], 136), + ((1u8..135).collect_vec(), 136), + ((1u8..135).collect_vec(), 134), + ((1u8..135).collect_vec(), 137), + ((1u8..135).collect_vec(), 272), + ((1u8..135).collect_vec(), 136 * 3), + ]; + base_test().k(18).lookup_bits(4).run(|ctx: &mut Context, range_chip: &RangeChip| { + build_and_verify_encode_var_len_bytes_vec(inputs, ctx, range_chip); + }) +} + +#[test] +fn prove_encode_var_len_bytes_vec() { + let init_inputs = vec![ + (vec![], 1), + (vec![], 136), + (vec![], 136), + (vec![], 137), + (vec![], 272), + (vec![], 136 * 3), + ]; + let inputs = vec![ + (vec![], 1), + (vec![], 136), + 
((1u8..135).collect_vec(), 136), + ((1u8..135).collect_vec(), 137), + ((1u8..135).collect_vec(), 272), + ((1u8..135).collect_vec(), 136 * 3), + ]; + base_test().k(18).lookup_bits(4).bench_builder( + init_inputs, + inputs, + |core, range_chip, inputs| { + let ctx = core.main(); + build_and_verify_encode_var_len_bytes_vec(inputs, ctx, range_chip); + }, + ); +} + +#[test] +fn mock_encode_fix_len_bytes_vec() { + let inputs = + vec![vec![], (1u8..135).collect_vec(), (0u8..136).collect_vec(), (0u8..211).collect_vec()]; + base_test().k(18).lookup_bits(4).run(|ctx: &mut Context, range_chip: &RangeChip| { + build_and_verify_encode_fix_len_bytes_vec(inputs, ctx, range_chip.gate()); + }); +} + +#[test] +fn prove_encode_fix_len_bytes_vec() { + let init_inputs = + vec![vec![], (2u8..136).collect_vec(), (1u8..137).collect_vec(), (2u8..213).collect_vec()]; + let inputs = + vec![vec![], (1u8..135).collect_vec(), (0u8..136).collect_vec(), (0u8..211).collect_vec()]; + base_test().k(18).lookup_bits(4).bench_builder( + init_inputs, + inputs, + |core, range_chip, inputs| { + let ctx = core.main(); + build_and_verify_encode_fix_len_bytes_vec(inputs, ctx, range_chip.gate()); + }, + ); +} diff --git a/hashes/zkevm/src/keccak/coprocessor/tests/mod.rs b/hashes/zkevm/src/keccak/coprocessor/tests/mod.rs index 63c4e272..520b3573 100644 --- a/hashes/zkevm/src/keccak/coprocessor/tests/mod.rs +++ b/hashes/zkevm/src/keccak/coprocessor/tests/mod.rs @@ -1,2 +1,4 @@ #[cfg(test)] +mod encode; +#[cfg(test)] mod output; diff --git a/hashes/zkevm/src/keccak/vanilla/mod.rs b/hashes/zkevm/src/keccak/vanilla/mod.rs index 90c461a4..b6941153 100644 --- a/hashes/zkevm/src/keccak/vanilla/mod.rs +++ b/hashes/zkevm/src/keccak/vanilla/mod.rs @@ -592,7 +592,7 @@ impl KeccakCircuitConfig { let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); let masked_input_bytes = input_bytes .iter() - .zip(is_paddings.clone()) + .zip_eq(is_paddings.clone()) .map(|(input_byte, is_padding)| { input_byte.expr.clone() * 
not::expr(is_padding.expr().clone()) })