Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions halo2-base/src/poseidon/hasher/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,22 @@ impl<F: ScalarField, const RATE: usize> PoseidonCompactInput<F, RATE> {
}
}

/// A compact chunk input for Poseidon hasher. The end of a logical input could only be at the boundary of a chunk.
#[derive(Clone, Debug)]
pub struct PoseidonCompactChunkInput<F: ScalarField, const RATE: usize> {
// Inputs of a chunk. All witnesses will be absorbed.
inputs: Vec<[AssignedValue<F>; RATE]>,
// is_final = 1 triggers squeeze.
is_final: SafeBool<F>,
}

impl<F: ScalarField, const RATE: usize> PoseidonCompactChunkInput<F, RATE> {
/// Create a new PoseidonCompactInput.
pub fn new(inputs: Vec<[AssignedValue<F>; RATE]>, is_final: SafeBool<F>) -> Self {
Self { inputs, is_final }
}
}

/// 1 logical row of compact output for Poseidon hasher.
#[derive(Copy, Clone, Debug, Getters)]
pub struct PoseidonCompactOutput<F: ScalarField> {
Expand Down Expand Up @@ -232,6 +248,36 @@ impl<F: ScalarField, const T: usize, const RATE: usize> PoseidonHasher<F, T, RAT
}
outputs
}

/// Constrains and returns hashes of chunk inputs in a compact format. Length of `chunk_inputs` should be determined at compile time.
pub fn hash_compact_chunk_inputs(
&self,
ctx: &mut Context<F>,
range: &impl RangeInstructions<F>,
chunk_inputs: &[PoseidonCompactChunkInput<F, RATE>],
) -> Vec<PoseidonCompactOutput<F>>
where
F: BigPrimeField,
{
let zero_witness = ctx.load_zero();
let mut outputs = Vec::with_capacity(chunk_inputs.len());
let mut state = self.init_state().clone();
for chunk_input in chunk_inputs {
let is_final = chunk_input.is_final;
for absorb in &chunk_input.inputs {
state.permutation(ctx, range.gate(), absorb, None, &self.spec);
}
// Because the length of each absorb is always RATE. An extra permutation is needed for squeeze.
let mut output_state = state.clone();
output_state.permutation(ctx, range.gate(), &[], None, &self.spec);
let hash =
range.gate().select(ctx, output_state.s[1], zero_witness, *is_final.as_ref());
outputs.push(PoseidonCompactOutput { hash, is_final });
// Reset state to init_state if this is the end of a logical input.
state.select(ctx, range.gate(), is_final, self.init_state());
}
outputs
}
}

/// Poseidon sponge. This is stateful.
Expand Down
123 changes: 122 additions & 1 deletion halo2-base/src/poseidon/hasher/tests/hasher.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use crate::{
gates::{range::RangeInstructions, RangeChip},
halo2_proofs::halo2curves::bn256::Fr,
poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonCompactInput, PoseidonHasher},
poseidon::hasher::{
spec::OptimizedPoseidonSpec, PoseidonCompactChunkInput, PoseidonCompactInput,
PoseidonHasher,
},
safe_types::SafeTypeChip,
utils::{testing::base_test, ScalarField},
Context,
};
use halo2_proofs_axiom::arithmetic::Field;
use itertools::Itertools;
use pse_poseidon::Poseidon;
use rand::Rng;

Expand Down Expand Up @@ -111,6 +115,61 @@ fn hasher_compact_inputs_compatiblity_verification<
}
}

// check if the results from hasher and native sponge are same for hash_compact_input.
fn hasher_compact_chunk_inputs_compatiblity_verification<
const T: usize,
const RATE: usize,
const R_F: usize,
const R_P: usize,
>(
payloads: Vec<(Payload<Fr>, bool)>,
ctx: &mut Context<Fr>,
range: &RangeChip<Fr>,
) {
// Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0.
let spec = OptimizedPoseidonSpec::<Fr, T, RATE>::new::<R_F, R_P, 0>();
let mut hasher = PoseidonHasher::<Fr, T, RATE>::new(spec);
hasher.initialize_consts(ctx, range.gate());

let mut native_results = Vec::with_capacity(payloads.len());
let mut chunk_inputs = Vec::<PoseidonCompactChunkInput<Fr, RATE>>::new();
let true_witness = SafeTypeChip::unsafe_to_bool(ctx.load_constant(Fr::ONE));
let false_witness = SafeTypeChip::unsafe_to_bool(ctx.load_zero());

// Construct native Poseidon sponge.
let mut native_sponge = Poseidon::<Fr, T, RATE>::new(R_F, R_P);
for (payload, is_final) in payloads {
assert!(payload.values.len() == payload.len);
assert!(payload.values.len() % RATE == 0);
let inputs = ctx.assign_witnesses(payload.values.clone());

let is_final_witness = if is_final { true_witness } else { false_witness };
chunk_inputs.push(PoseidonCompactChunkInput {
inputs: inputs.chunks(RATE).map(|c| c.try_into().unwrap()).collect_vec(),
is_final: is_final_witness,
});
native_sponge.update(&payload.values);
if is_final {
let native_result = native_sponge.squeeze();
native_results.push(native_result);
native_sponge = Poseidon::<Fr, T, RATE>::new(R_F, R_P);
}
}
let compact_outputs = hasher.hash_compact_chunk_inputs(ctx, range, &chunk_inputs);
assert_eq!(chunk_inputs.len(), compact_outputs.len());
let mut output_offset = 0;
for (compact_output, chunk_input) in compact_outputs.iter().zip(chunk_inputs) {
// into() doesn't work if ! is in the beginning in the bool expression...
let is_final_input = chunk_input.is_final.as_ref().value();
let is_final_output = compact_output.is_final().as_ref().value();
assert_eq!(is_final_input, is_final_output);
if is_final_output == &Fr::ONE {
assert_eq!(native_results[output_offset], *compact_output.hash().value());
output_offset += 1;
}
}
}

fn random_payload<F: ScalarField>(max_len: usize, len: usize, max_value: usize) -> Payload<F> {
assert!(len <= max_len);
let mut rng = rand::thread_rng();
Expand Down Expand Up @@ -235,3 +294,65 @@ fn test_poseidon_hasher_compact_inputs_with_prover() {
});
}
}

#[test]
fn test_poseidon_hasher_compact_chunk_inputs() {
{
const T: usize = 3;
const RATE: usize = 2;
let payloads = vec![
(random_payload(RATE * 5, RATE * 5, usize::MAX), true),
(random_payload(RATE, RATE, usize::MAX), false),
(random_payload(RATE * 2, RATE * 2, usize::MAX), true),
(random_payload(RATE * 3, RATE * 3, usize::MAX), true),
];
base_test().k(12).run(|ctx, range| {
hasher_compact_chunk_inputs_compatiblity_verification::<T, RATE, 8, 57>(
payloads, ctx, range,
);
});
}
{
const T: usize = 3;
const RATE: usize = 2;
let payloads = vec![
(random_payload(0, 0, usize::MAX), true),
(random_payload(0, 0, usize::MAX), false),
(random_payload(0, 0, usize::MAX), false),
];
base_test().k(12).run(|ctx, range| {
hasher_compact_chunk_inputs_compatiblity_verification::<T, RATE, 8, 57>(
payloads, ctx, range,
);
});
}
}

#[test]
fn test_poseidon_hasher_compact_chunk_inputs_with_prover() {
{
const T: usize = 3;
const RATE: usize = 2;
let params = [
(RATE, false),
(RATE * 2, false),
(RATE * 5, false),
(RATE * 2, true),
(RATE * 5, true),
];
let init_payloads = params
.iter()
.map(|(len, is_final)| (random_payload(*len, *len, usize::MAX), *is_final))
.collect::<Vec<_>>();
let logic_payloads = params
.iter()
.map(|(len, is_final)| (random_payload(*len, *len, usize::MAX), *is_final))
.collect::<Vec<_>>();
base_test().k(12).bench_builder(init_payloads, logic_payloads, |pool, range, input| {
let ctx = pool.main();
hasher_compact_chunk_inputs_compatiblity_verification::<T, RATE, 8, 57>(
input, ctx, range,
);
});
}
}
58 changes: 56 additions & 2 deletions halo2-base/src/safe_types/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ impl<F: ScalarField, const MAX_LEN: usize> VarLenBytes<F, MAX_LEN> {
padded.into_iter().map(|b| SafeByte(b)).collect::<Vec<_>>().try_into().unwrap(),
)
}

/// Return a copy of the byte array with 0 padding ensured.
pub fn ensure_0_padding(&self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) -> Self {
let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len);
Self::new(bytes.try_into().unwrap(), self.len)
}
}

/// Represents a variable length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time.
Expand Down Expand Up @@ -93,7 +99,13 @@ impl<F: ScalarField> VarLenBytesVec<F> {
gate: &impl GateInstructions<F>,
) -> FixLenBytesVec<F> {
let padded = left_pad_var_array_to_fixed(ctx, gate, &self.bytes, self.len, self.max_len());
padded.into_iter().map(|b| SafeByte(b)).collect()
FixLenBytesVec::new(padded.into_iter().map(|b| SafeByte(b)).collect_vec(), self.max_len())
}

/// Return a copy of the byte array with 0 padding ensured.
pub fn ensure_0_padding(&self, ctx: &mut Context<F>, gate: &impl GateInstructions<F>) -> Self {
let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len);
Self::new(bytes, self.len, self.max_len())
}
}

Expand All @@ -117,6 +129,27 @@ impl<F: ScalarField, const LEN: usize> FixLenBytes<F, LEN> {
}
}

/// Represents a fixed length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time.
#[derive(Debug, Clone, Getters)]
pub struct FixLenBytesVec<F: ScalarField> {
/// The byte array
#[getset(get = "pub")]
bytes: Vec<SafeByte<F>>,
}

impl<F: ScalarField> FixLenBytesVec<F> {
// FixLenBytes can be only created by SafeChip.
pub(super) fn new(bytes: Vec<SafeByte<F>>, len: usize) -> Self {
assert_eq!(bytes.len(), len, "bytes length doesn't match");
Self { bytes }
}

/// Returns the length of the byte array.
pub fn len(&self) -> usize {
self.bytes.len()
}
}

impl<F: ScalarField, const TOTAL_BITS: usize> From<SafeType<F, 1, TOTAL_BITS>>
for FixLenBytes<F, { SafeType::<F, 1, TOTAL_BITS>::VALUE_LENGTH }>
{
Expand All @@ -138,7 +171,7 @@ impl<F: ScalarField, const TOTAL_BITS: usize>

/// Represents a fixed length byte array in circuit as a vector, where length must be fixed.
/// Not encouraged to use because `LEN` cannot be verified at compile time.
pub type FixLenBytesVec<F> = Vec<SafeByte<F>>;
// pub type FixLenBytesVec<F> = Vec<SafeByte<F>>;

/// Takes a fixed length array `arr` and returns a length `out_len` array equal to
/// `[[0; out_len - len], arr[..len]].concat()`, i.e., we take `arr[..len]` and
Expand Down Expand Up @@ -172,3 +205,24 @@ pub fn left_pad_var_array_to_fixed<F: ScalarField>(
}
padded
}

fn ensure_0_padding<F: ScalarField>(
ctx: &mut Context<F>,
gate: &impl GateInstructions<F>,
bytes: &[SafeByte<F>],
len: AssignedValue<F>,
) -> Vec<SafeByte<F>> {
let max_len = bytes.len();
// Generate a mask array where a[i] = i < len for i = 0..max_len.
let idx = gate.dec(ctx, len);
let len_indicator = gate.idx_to_indicator(ctx, idx, max_len);
// inputs_mask[i] = sum(len_indicator[i..])
let mut mask = gate.partial_sums(ctx, len_indicator.clone().into_iter().rev()).collect_vec();
mask.reverse();

bytes
.iter()
.zip(mask.iter())
.map(|(byte, mask)| SafeByte(gate.mul(ctx, byte.0, *mask)))
.collect_vec()
}
31 changes: 30 additions & 1 deletion halo2-base/src/safe_types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,18 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> {
FixLenBytes::<F, MAX_LEN>::new(inputs.map(|input| Self::unsafe_to_byte(input)))
}

/// Unsafe method that directly converts `inputs` to [`FixLenBytesVec`] **without any checks**.
/// This should **only** be used if an external library needs to convert their types to [`SafeByte`].
pub fn unsafe_to_fix_len_bytes_vec(
inputs: RawAssignedValues<F>,
len: usize,
) -> FixLenBytesVec<F> {
FixLenBytesVec::<F>::new(
inputs.into_iter().map(|input| Self::unsafe_to_byte(input)).collect_vec(),
len,
)
}

/// Converts a slice of AssignedValue(treated as little-endian) to VarLenBytes.
///
/// * ctx: Circuit [Context]<F> to assign witnesses to.
Expand All @@ -249,7 +261,7 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> {
/// * ctx: Circuit [Context]<F> to assign witnesses to.
/// * inputs: Vector representing the byte array, right padded to `max_len`. See [VarLenBytesVec] for details about padding.
/// * len: [AssignedValue]<F> witness representing the variable length of the byte array. Constrained to be `<= max_len`.
/// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain.
/// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain. We enforce this to be provided explictly to make sure length of `inputs` is determinstic.
pub fn raw_to_var_len_bytes_vec(
&self,
ctx: &mut Context<F>,
Expand Down Expand Up @@ -278,6 +290,23 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> {
FixLenBytes::<F, LEN>::new(inputs.map(|input| self.assert_byte(ctx, input)))
}

/// Converts a slice of AssignedValue(treated as little-endian) to FixLenBytesVec.
///
/// * ctx: Circuit [Context]<F> to assign witnesses to.
/// * inputs: Slice representing the byte array.
/// * len: length of the byte array. We enforce this to be provided explictly to make sure length of `inputs` is determinstic.
pub fn raw_to_fix_len_bytes_vec(
&self,
ctx: &mut Context<F>,
inputs: RawAssignedValues<F>,
len: usize,
) -> FixLenBytesVec<F> {
FixLenBytesVec::<F>::new(
inputs.into_iter().map(|input| self.assert_byte(ctx, input)).collect_vec(),
len,
)
}

fn add_bytes_constraints(
&self,
ctx: &mut Context<F>,
Expand Down
29 changes: 27 additions & 2 deletions halo2-base/src/safe_types/tests/bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fn left_pad_var_len_bytes(mut bytes: Vec<u8>, max_len: usize) -> Vec<u8> {
let len = ctx.load_witness(Fr::from(len as u64));
let bytes = safe.raw_to_var_len_bytes_vec(ctx, bytes, len, max_len);
let padded = bytes.left_pad_to_fixed(ctx, range.gate());
padded.iter().map(|b| b.as_ref().value().get_lower_64() as u8).collect()
padded.bytes().iter().map(|b| b.as_ref().value().get_lower_64() as u8).collect()
})
}

Expand Down Expand Up @@ -132,7 +132,7 @@ fn neg_var_len_bytes_vec_len_less_than_max_len() {

// Circuit Satisfied for valid inputs
#[test]
fn pos_fix_len_bytes_vec() {
fn pos_fix_len_bytes() {
base_test().k(10).lookup_bits(8).run(|ctx, range| {
let safe = SafeTypeChip::new(range);
let fake_bytes = ctx.assign_witnesses(
Expand All @@ -142,6 +142,31 @@ fn pos_fix_len_bytes_vec() {
});
}

// Assert inputs.len() == len
#[test]
#[should_panic]
fn neg_fix_len_bytes_vec() {
base_test().k(10).lookup_bits(8).run(|ctx, range| {
let safe = SafeTypeChip::new(range);
let fake_bytes = ctx.assign_witnesses(
vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::<Vec<_>>(),
);
safe.raw_to_fix_len_bytes_vec(ctx, fake_bytes, 5);
});
}

// Circuit Satisfied for valid inputs
#[test]
fn pos_fix_len_bytes_vec() {
base_test().k(10).lookup_bits(8).run(|ctx, range| {
let safe = SafeTypeChip::new(range);
let fake_bytes = ctx.assign_witnesses(
vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::<Vec<_>>(),
);
safe.raw_to_fix_len_bytes_vec(ctx, fake_bytes, 4);
});
}

// =========== Prover ===========
#[test]
fn pos_prover_satisfied() {
Expand Down
Loading