Skip to content

Commit

Permalink
feature(merkle-chunk): conditional select on boolean & working example
Browse files Browse the repository at this point in the history
  • Loading branch information
tchataigner committed Feb 29, 2024
1 parent 90a9b41 commit e59ac2d
Show file tree
Hide file tree
Showing 8 changed files with 94 additions and 111 deletions.
1 change: 0 additions & 1 deletion crates/chunk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ bellpepper-merkle-inclusion = { path = "../merkle-inclusion" }
bincode = "1.3.3"
bitvec = "1.0.1"
halo2curves = { version = "0.6.0", features = ["bits", "derive_serde"] }
itertools = "0.12.1"
paste = "1.0.14"
sha3 = "0.11.0-pre.3"
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
Expand Down
117 changes: 23 additions & 94 deletions crates/chunk/examples/chunk_merkle_proving.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,17 @@ use arecibo::supernova::{
};
use arecibo::traits::snark::default_ck_hint;
use arecibo::traits::{CurveCycleEquipped, Dual, Engine};
use bellpepper::gadgets::boolean::{field_into_boolean_vec_le, Boolean};
use bellpepper::gadgets::boolean::Boolean;
use bellpepper::gadgets::multipack::{bytes_to_bits_le, compute_multipacking, pack_bits};
use bellpepper::gadgets::num::AllocatedNum;
use bellpepper::gadgets::Assignment;
use bellpepper_chunk::traits::{ChunkCircuitInner, ChunkStepCircuit};
use bellpepper_chunk::{FoldStep, InnerCircuit};
use bellpepper_core::boolean::{field_into_allocated_bits_le, AllocatedBit};
use bellpepper_core::{ConstraintSystem, SynthesisError};
use bellpepper_keccak::sha3;
use bellpepper_merkle_inclusion::traits::GadgetDigest;
use bellpepper_merkle_inclusion::{create_gadget_digest_impl, hash_equality};
use bitvec::order::Lsb0;
use bitvec::prelude::BitVec;
use bellpepper_merkle_inclusion::{conditional_hash, create_gadget_digest_impl, hash_equality};
use ff::{Field, PrimeField, PrimeFieldBits};
use halo2curves::bn256::Bn256;
use itertools::Itertools;
use sha3::digest::Output;
use sha3::{Digest, Sha3_256};
use std::marker::PhantomData;
Expand Down Expand Up @@ -57,7 +52,7 @@ fn reconstruct_hash<F: PrimeField + PrimeFieldBits, CS: ConstraintSystem<F>>(
) -> Vec<Boolean> {
// Compute the bit sizes of the field elements
let mut scalar_bit_sizes: Vec<usize> = (0..bit_size / F::CAPACITY as usize)
.map(|i| F::CAPACITY as usize)
.map(|_| F::CAPACITY as usize)
.collect();
// If the bit size is not a multiple of 253, we need to add the remaining bits
if bit_size % F::CAPACITY as usize != 0 {
Expand Down Expand Up @@ -86,48 +81,6 @@ fn reconstruct_hash<F: PrimeField + PrimeFieldBits, CS: ConstraintSystem<F>>(
result
}

pub fn conditionally_select<F: PrimeField, CS: ConstraintSystem<F>>(
mut cs: CS,
a: &AllocatedNum<F>,
b: &AllocatedNum<F>,
condition: &Boolean,
) -> Result<AllocatedNum<F>, SynthesisError> {
let c = AllocatedNum::alloc(cs.namespace(|| "conditional select result"), || {
if *condition.get_value().get()? {
Ok(*a.get_value().get()?)
} else {
Ok(*b.get_value().get()?)
}
})?;

// a * condition + b*(1-condition) = c ->
// a * condition - b*condition = c - b
cs.enforce(
|| "conditional select constraint",
|lc| lc + a.get_variable() - b.get_variable(),
|_| condition.lc(CS::one(), F::ONE),
|lc| lc + c.get_variable() - b.get_variable(),
);

Ok(c)
}

/// If condition return a otherwise b
pub fn conditionally_select_vec<F: PrimeField, CS: ConstraintSystem<F>>(
mut cs: CS,
a: &[AllocatedNum<F>],
b: &[AllocatedNum<F>],
condition: &Boolean,
) -> Result<Vec<AllocatedNum<F>>, SynthesisError> {
a.iter()
.zip_eq(b.iter())
.enumerate()
.map(|(i, (a, b))| {
conditionally_select(cs.namespace(|| format!("select_{i}")), a, b, condition)
})
.collect::<Result<Vec<AllocatedNum<F>>, SynthesisError>>()
}

/*****************************************
* Circuit
*****************************************/
Expand Down Expand Up @@ -193,14 +146,13 @@ impl<E1: CurveCycleEquipped, C: ChunkStepCircuit<E1::Scalar>, const N: usize> No
}

fn primary_circuit(&self, circuit_index: usize) -> Self::C1 {
match circuit_index {
2 => Self::C1::CheckEquality(EqualityCircuit::new()),
_ => {
if let Some(fold_step) = self.inner.circuits().get(circuit_index) {
return Self::C1::IterStep(FoldStepWrapper::new(fold_step.clone()));
}
panic!("No circuit found for index {}", circuit_index)
if circuit_index == 2 {
Self::C1::CheckEquality(EqualityCircuit::new())
} else {
if let Some(fold_step) = self.inner.circuits().get(circuit_index) {
return Self::C1::IterStep(FoldStepWrapper::new(fold_step.clone()));
}
panic!("No circuit found for index {}", circuit_index)
}
}

Expand Down Expand Up @@ -243,32 +195,24 @@ impl<F: PrimeField + PrimeFieldBits> ChunkStepCircuit<F> for ChunkStep<F> {
.to_bits_le(&mut cs.namespace(|| "get positional bit"))
.unwrap()[0];

let hash_order = conditionally_select_vec(
&mut cs.namespace(|| "conditional ordering"),
&[&chunk_in[1..3], &z[0..2]].concat(),
&[&z[0..2], &chunk_in[1..3]].concat(),
boolean,
)?;

let mut first_hash = reconstruct_hash(
&mut cs.namespace(|| "reconstruct acc hash"),
&hash_order[0..2],
256,
);
let acc = reconstruct_hash(&mut cs.namespace(|| "reconstruct acc hash"), &z[0..2], 256);

let mut second_hash = reconstruct_hash(
let sibling = reconstruct_hash(
&mut cs.namespace(|| "reconstruct_sibling_hash"),
&hash_order[2..],
&chunk_in[1..3],
256,
);
first_hash.append(&mut second_hash);

let new_acc = sha3(&mut cs.namespace(|| "hash new acc"), &first_hash)?;
let new_acc = conditional_hash::<_, _, Sha3>(
&mut cs.namespace(|| "conditional_hash"),
&acc,
&sibling,
boolean,
)?;

let new_acc_f_1 = pack_bits(&mut cs.namespace(|| "pack_bits new_acc 1"), &new_acc[..253])?;
let new_acc_f_2 = pack_bits(&mut cs.namespace(|| "pack_bits new_acc 2"), &new_acc[253..])?;
dbg!(&new_acc_f_1);
dbg!(&new_acc_f_2);

let z_out = vec![new_acc_f_1, new_acc_f_2, z[2].clone(), z[3].clone()];

Ok(z_out)
Expand Down Expand Up @@ -378,7 +322,7 @@ fn main() {
// Leaf and root hashes
let a_leaf_hash =
hash::<<Sha3 as GadgetDigest<<E1 as Engine>::Scalar>>::OutOfCircuitHasher>("a".as_bytes());
let mut b_leaf_hash =
let b_leaf_hash =
hash::<<Sha3 as GadgetDigest<<E1 as Engine>::Scalar>>::OutOfCircuitHasher>("b".as_bytes());
let c_leaf_hash =
hash::<<Sha3 as GadgetDigest<<E1 as Engine>::Scalar>>::OutOfCircuitHasher>("c".as_bytes());
Expand All @@ -388,9 +332,6 @@ fn main() {
let ab_leaf_hash = hash::<<Sha3 as GadgetDigest<<E1 as Engine>::Scalar>>::OutOfCircuitHasher>(
&[a_leaf_hash, b_leaf_hash].concat(),
);
dbg!(compute_multipacking::<<E1 as Engine>::Scalar>(
&bytes_to_bits_le(&ab_leaf_hash)
));
let cd_leaf_hash = hash::<<Sha3 as GadgetDigest<<E1 as Engine>::Scalar>>::OutOfCircuitHasher>(
&[c_leaf_hash, d_leaf_hash].concat(),
);
Expand All @@ -400,10 +341,9 @@ fn main() {
);

// Intermediate hashes
let mut intermediate_hashes: Vec<<E1 as Engine>::Scalar> = vec![a_leaf_hash, cd_leaf_hash]
let intermediate_hashes: Vec<<E1 as Engine>::Scalar> = [a_leaf_hash, cd_leaf_hash]
.iter()
.map(|h| compute_multipacking::<<E1 as Engine>::Scalar>(&bytes_to_bits_le(h)))
.flatten()
.flat_map(|h| compute_multipacking::<<E1 as Engine>::Scalar>(&bytes_to_bits_le(h)))
.collect();
let mut intermediate_key_hashes = vec![<E1 as Engine>::Scalar::ONE];
intermediate_key_hashes.append(&mut intermediate_hashes[0..2].to_vec());
Expand All @@ -420,7 +360,7 @@ fn main() {
// Multipacking the leaf and root hashes
let mut z0_primary =
compute_multipacking::<<E1 as Engine>::Scalar>(&bytes_to_bits_le(&b_leaf_hash));
let mut root_fields =
let root_fields =
compute_multipacking::<<E1 as Engine>::Scalar>(&bytes_to_bits_le(&abcd_leaf_hash));

// The accumulator elements are initialized to 0
Expand Down Expand Up @@ -460,17 +400,6 @@ fn main() {
res.is_ok(),
start.elapsed()
);

let start = Instant::now();

let res = recursive_snark.verify(&pp, &z0_primary, &z0_secondary);
assert!(res.is_ok());
println!(
"RecursiveSNARK::verify {}: {:?}, took {:?} ",
step,
res.is_ok(),
start.elapsed()
);
}

println!("Generating a CompressedSNARK...");
Expand Down
1 change: 1 addition & 0 deletions crates/chunk/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub trait ChunkCircuitInner<F: PrimeField, C: ChunkStepCircuit<F>, const N: usiz
/// `new` must return a new instance of the chunk circuit.
/// # Arguments
/// * `intermediate_steps_input` - The intermediate input values for each of the step circuits.
/// * `post_processing_circuit` - The post processing circuit to be used after the loop of steps.
///
/// # Note
///
Expand Down
2 changes: 1 addition & 1 deletion crates/chunk/tests/gadget.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ fn verify_chunk_circuit<F: PrimeField, C: ChunkStepCircuit<F>, const N: usize>()

let expected = (test_inputs.len() / N) + if test_inputs.len() % N != 0 { 2 } else { 1 };

let circuit = InnerCircuit::<F, C, N>::new(&test_inputs).unwrap();
let circuit = InnerCircuit::<F, C, N>::new(&test_inputs, None).unwrap();

let actual = circuit.num_fold_steps();

Expand Down
2 changes: 2 additions & 0 deletions crates/merkle-inclusion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ rust-version = "1.71.1"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
bellpepper = { workspace = true }
bellpepper-core = { workspace = true }
digest = "0.11.0-pre.4"
ff = { workspace = true }
itertools = "0.12.1"

[dev-dependencies]
bellpepper-keccak = { path="../keccak", version = "0.1.0" }
Expand Down
19 changes: 8 additions & 11 deletions crates/merkle-inclusion/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
pub mod traits;
mod utils;

use crate::traits::GadgetDigest;
use crate::utils::conditionally_select_vec;
use bellpepper_core::boolean::Boolean;
use bellpepper_core::{ConstraintSystem, SynthesisError};
use ff::PrimeField;
Expand Down Expand Up @@ -126,17 +128,12 @@ where
GD: GadgetDigest<E>,
{
// Determine the order of hashing based on the bit value.
let hash_order: Vec<Boolean> = if bit.get_value() == Some(true) {
vec![sibling.to_owned(), acc.to_vec()]
.into_iter()
.flatten()
.collect()
} else {
vec![acc.to_vec(), sibling.to_owned()]
.into_iter()
.flatten()
.collect()
};
let hash_order: Vec<Boolean> = conditionally_select_vec(
&mut cs.namespace(|| "hash order"),
&[sibling, acc].concat(),
&[acc, sibling].concat(),
bit,
)?;

// Compute the new hash.
let new_acc = GD::digest(&mut cs.namespace(|| "digest leaf & sibling"), &hash_order)?;
Expand Down
48 changes: 48 additions & 0 deletions crates/merkle-inclusion/src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use bellpepper_core::boolean::{AllocatedBit, Boolean};
use bellpepper_core::{ConstraintSystem, SynthesisError};
use ff::PrimeField;
use itertools::Itertools;
use std::ops::Sub;

pub fn conditionally_select_bool<F: PrimeField, CS: ConstraintSystem<F>>(
mut cs: CS,
a: &Boolean,
b: &Boolean,
condition: &Boolean,
) -> Result<Boolean, SynthesisError> {
let value = if condition.get_value().unwrap_or_default() {
a.get_value()
} else {
b.get_value()
};

let result = Boolean::Is(AllocatedBit::alloc(
&mut cs.namespace(|| "conditional select result"),
value,
)?);

cs.enforce(
|| "conditional select constraint",
|_| condition.lc(CS::one(), F::ONE),
|_| a.lc(CS::one(), F::ONE).sub(&b.lc(CS::one(), F::ONE)),
|_| result.lc(CS::one(), F::ONE).sub(&b.lc(CS::one(), F::ONE)),
);

Ok(result)
}

/// If condition return a otherwise b
pub fn conditionally_select_vec<F: PrimeField, CS: ConstraintSystem<F>>(
mut cs: CS,
a: &[Boolean],
b: &[Boolean],
condition: &Boolean,
) -> Result<Vec<Boolean>, SynthesisError> {
a.iter()
.zip_eq(b.iter())
.enumerate()
.map(|(i, (a, b))| {
conditionally_select_bool(cs.namespace(|| format!("select_{i}")), a, b, condition)
})
.collect::<Result<Vec<Boolean>, SynthesisError>>()
}
15 changes: 11 additions & 4 deletions crates/merkle-inclusion/tests/gadget.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ fn expected_circuit_constraints<GD: GadgetDigest<Scalar>>(
hasher_constraints: usize,
nbr_siblings: usize,
) -> usize {
3 * GD::output_size() * 8 + (GD::output_size() * 8 + hasher_constraints + 1) * nbr_siblings
3 * GD::output_size() * 8
+ (GD::output_size() * 8 + hasher_constraints + 1) * nbr_siblings
+ nbr_siblings * 4 * 8 * GD::output_size()
}

fn verify_inclusion_merkle<GD: GadgetDigest<Scalar>, O: BitOrder>() {
Expand Down Expand Up @@ -105,11 +107,16 @@ fn verify_incorrect_sibling_hashes<GD: GadgetDigest<Scalar>, O: BitOrder>() {
],
);

let cs = TestConstraintSystem::<<Bls12 as Engine>::Fr>::new();
let res = verify_proof::<_, _, GD>(cs, &bytes_to_bitvec::<O>(simple_tree.root()), &proof);
let mut cs = TestConstraintSystem::<<Bls12 as Engine>::Fr>::new();
verify_proof::<_, _, GD>(
&mut cs.namespace(|| "verify proof"),
&bytes_to_bitvec::<O>(simple_tree.root()),
&proof,
)
.expect("verify_proof should end with Ok");

assert!(
res.is_err(),
!cs.is_satisfied(),
"Proof verification should fail with incorrect sibling hashes."
);
}
Expand Down

0 comments on commit e59ac2d

Please sign in to comment.