From aebcefee5cc13f33caf115c329986c23aacbfbe7 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Thu, 20 Oct 2022 11:29:50 -0700 Subject: [PATCH 01/28] chore: add display feature to Cargo.toml --- Cargo.toml | 5 +++-- configs/verify_circuit.config | 2 +- examples/evm-verifier-with-accumulator.rs | 4 ++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 747050f9..6df8b804 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,10 +43,11 @@ zkevm_circuit_benchmarks = {git = "https://github.com/privacy-scaling-exploratio zkevm_circuits = {git = "https://github.com/privacy-scaling-explorations/zkevm-circuits.git", package = "zkevm-circuits" } [features] -default = ["loader_evm", "loader_halo2", "system_halo2"] +default = ["loader_evm", "loader_halo2", "system_halo2", "display"] loader_evm = ["dep:ethereum_types", "dep:sha3", "dep:foundry_evm"] -loader_halo2 = ["dep:halo2_proofs", "dep:halo2_base", "dep:halo2_ecc", "dep:poseidon"] +loader_halo2 = ["dep:halo2_proofs", "dep:halo2_base", "halo2_ecc", "dep:poseidon"] system_halo2 = ["dep:halo2_proofs"] +display = ["halo2_ecc/display"] sanity_check = [] [patch."https://github.com/privacy-scaling-explorations/halo2"] diff --git a/configs/verify_circuit.config b/configs/verify_circuit.config index 37e349a8..6f70df6a 100644 --- a/configs/verify_circuit.config +++ b/configs/verify_circuit.config @@ -1 +1 @@ -{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3} \ No newline at end of file +{"strategy":"Simple","degree":23,"num_advice":7,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":22,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/examples/evm-verifier-with-accumulator.rs b/examples/evm-verifier-with-accumulator.rs index 32e25a53..77d58f26 100644 --- a/examples/evm-verifier-with-accumulator.rs +++ b/examples/evm-verifier-with-accumulator.rs @@ -179,7 +179,7 @@ fn gen_proof< MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); // For testing purposes: Native verify // Uncomment to test if evm verifier fails silently - { + /*{ let proof = { let mut transcript = Blake2bWrite::init(Vec::new()); create_proof::< @@ -213,7 +213,7 @@ fn gen_proof< let instances = &[instances[0].to_vec()]; let proof = Plonk::read_proof(&svk, &protocol, instances, &mut transcript).unwrap(); assert!(Plonk::verify(&svk, &dk, &protocol, instances, &proof).unwrap()); - } + }*/ let instances = instances.iter().map(|instances| instances.as_slice()).collect_vec(); let proof = { From 5e882920f1e363b269ecd6cad9d59163db934402 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Thu, 20 Oct 2022 12:37:11 -0700 Subject: [PATCH 02/28] fix: change test data directory to ./data and give error message --- src/system/halo2/test/kzg/halo2.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/system/halo2/test/kzg/halo2.rs b/src/system/halo2/test/kzg/halo2.rs index e332b036..e0cf3339 100644 --- a/src/system/halo2/test/kzg/halo2.rs +++ b/src/system/halo2/test/kzg/halo2.rs @@ -431,7 +431,7 @@ pub fn create_snark() -> (ParamsKZG, Snark) { // TODO: need to cache the instances as well! let proof = { - let path = format!("./src/system/halo2/test/data/proof_{}.data", T::NAME); + let path = format!("./data/proof_{}.data", T::NAME); match std::fs::File::open(path.as_str()) { Ok(mut file) => { let mut buf = vec![]; @@ -451,7 +451,8 @@ pub fn create_snark() -> (ParamsKZG, Snark) { ) .unwrap(); let proof = transcript.finalize(); - let mut file = std::fs::File::create(path.as_str()).unwrap(); + let mut file = std::fs::File::create(path.as_str()) + .expect(format!("{:?} should exist", path).as_str()); file.write_all(&proof).unwrap(); proof } From b0d0910fd5421c56b69365ea6ab44b03179c1708 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Thu, 20 Oct 2022 22:32:02 -0700 Subject: [PATCH 03/28] feat: add succinct_verify_or_dummy to allow flag to turn off verification in aggregation snark --- configs/verify_circuit.config | 2 +- examples/evm-verifier-with-accumulator.rs | 3 +- src/loader.rs | 9 ++ src/loader/halo2/loader.rs | 25 +++- src/pcs.rs | 19 ++- src/pcs/kzg/multiopen/bdfg21.rs | 162 +++++++++++----------- src/util/msm.rs | 27 +--- src/verifier.rs | 8 ++ src/verifier/plonk.rs | 41 ++++++ 9 files changed, 188 insertions(+), 108 deletions(-) diff --git a/configs/verify_circuit.config b/configs/verify_circuit.config index 6f70df6a..1b34fdc7 100644 --- a/configs/verify_circuit.config +++ b/configs/verify_circuit.config @@ -1 +1 @@ -{"strategy":"Simple","degree":23,"num_advice":7,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":22,"limb_bits":88,"num_limbs":3} \ No newline at end of file +{"strategy":"Simple","degree":20,"num_advice":7,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/examples/evm-verifier-with-accumulator.rs b/examples/evm-verifier-with-accumulator.rs index 77d58f26..1bf58f28 100644 --- a/examples/evm-verifier-with-accumulator.rs +++ b/examples/evm-verifier-with-accumulator.rs @@ -342,6 +342,7 @@ fn main() { end_timer!(deploy_time); fs::write("./data/verifier_bytecode.dat", hex::encode(&deployment_code)).unwrap(); + /* // use different input snarks to test instances etc let app_circuit = StandardPlonk::rand(OsRng); let snark = create_snark_shplonk::( @@ -351,7 +352,7 @@ fn main() { None, ); let snarks = vec![snark]; - let agg_circuit = AggregationCircuit::new(¶ms, snarks, true); + let agg_circuit = AggregationCircuit::new(¶ms, snarks, true); */ let proof_time = start_timer!(|| "create agg_circuit proof"); let proof = gen_proof::<_, _, EvmTranscript, EvmTranscript>( ¶ms, diff --git a/src/loader.rs b/src/loader.rs index 2308b08b..f8f0ad77 100644 --- a/src/loader.rs +++ b/src/loader.rs @@ -209,6 +209,15 @@ pub trait ScalarLoader { pub trait Loader: EcPointLoader + ScalarLoader + Clone + Debug { + fn ec_point_select( + &self, + _a: &Self::LoadedEcPoint, + _b: &Self::LoadedEcPoint, + _sel: &Self::LoadedScalar, + ) -> Result { + todo!() + } + fn start_cost_metering(&self, _: &str) {} fn end_cost_metering(&self) {} diff --git a/src/loader/halo2/loader.rs b/src/loader/halo2/loader.rs index 003f0811..04b5a2f4 100644 --- a/src/loader/halo2/loader.rs +++ b/src/loader/halo2/loader.rs @@ -875,6 +875,10 @@ impl<'a, 'b, C: CurveAffine> EcPointLoader for Rc> { self.assign_const_ec_point(*ec_point) } + fn ec_point_load_one(&self) -> Self::LoadedEcPoint { + self.ec_point(self.assign_const_ec_point(C::generator()).assigned()) + } + fn ec_point_assert_eq( &self, annotation: &str, @@ -897,4 +901,23 @@ impl<'a, 'b, C: CurveAffine> EcPointLoader for Rc> { } } -impl<'a, 'b, C: CurveAffine> Loader for Rc> {} +impl<'a, 'b, C: CurveAffine> Loader for Rc> { + // only using this when `sel = use_dummy` and `a` is dummy + // you should never use dummy if `b` is constant + fn ec_point_select( + &self, + a: &Self::LoadedEcPoint, + b: &Self::LoadedEcPoint, + sel: &Self::LoadedScalar, + ) -> Result { + if matches!(b.value, Value::Constant(_)) { + return Ok(b.clone()); + } + let a = a.assigned(); + let b = b.assigned(); + let sel = sel.assigned(); + let assigned = halo2_ecc::ecc::select(self.field_chip(), &mut self.ctx_mut(), &a, &b, &sel) + .expect("ec_point_select should not fail"); + Ok(self.ec_point(assigned)) + } +} diff --git a/src/pcs.rs b/src/pcs.rs index 65804895..97ce8ece 100644 --- a/src/pcs.rs +++ b/src/pcs.rs @@ -29,11 +29,7 @@ pub struct Query { impl Query { pub fn with_evaluation(self, eval: T) -> Query { - Query { - poly: self.poly, - shift: self.shift, - eval, - } + Query { poly: self.poly, shift: self.shift, eval } } } @@ -60,6 +56,19 @@ where queries: &[Query], proof: &Self::Proof, ) -> Result; + + // same as succinct_verify except `use_dummy` is boolean loaded scalar + // if `use_dummy` is 1, then put in dummy values to MSM so constraints are satisfies regardless of `proof` values + fn succinct_verify_or_dummy( + _svk: &Self::SuccinctVerifyingKey, + _commitments: &[Msm], + _point: &L::LoadedScalar, + _queries: &[Query], + _proof: &Self::Proof, + _use_dummy: &L::LoadedScalar, + ) -> Result { + todo!() + } } pub trait Decider: PolynomialCommitmentScheme diff --git a/src/pcs/kzg/multiopen/bdfg21.rs b/src/pcs/kzg/multiopen/bdfg21.rs index fb012dea..fcc71f8c 100644 --- a/src/pcs/kzg/multiopen/bdfg21.rs +++ b/src/pcs/kzg/multiopen/bdfg21.rs @@ -51,9 +51,8 @@ where let sets = query_sets(queries); let coeffs = query_set_coeffs(&sets, z, &proof.z_prime); - let powers_of_mu = proof - .mu - .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); + let powers_of_mu = + proof.mu.powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); let msms = sets .iter() .zip(coeffs.iter()) @@ -68,10 +67,47 @@ where let rhs = Msm::base(proof.w_prime.clone()); let lhs = f + rhs.clone() * &proof.z_prime; - Ok(KzgAccumulator::new( - lhs.evaluate(Some(svk.g)), - rhs.evaluate(Some(svk.g)), - )) + Ok(KzgAccumulator::new(lhs.evaluate(Some(svk.g)), rhs.evaluate(Some(svk.g)))) + } + + fn succinct_verify_or_dummy( + svk: &KzgSuccinctVerifyingKey, + commitments: &[Msm], + z: &L::LoadedScalar, + queries: &[Query], + proof: &Bdfg21Proof, + use_dummy: &L::LoadedScalar, + ) -> Result { + let f = { + let sets = query_sets(queries); + let coeffs = query_set_coeffs(&sets, z, &proof.z_prime); + + let powers_of_mu = + proof.mu.powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); + let msms = sets + .iter() + .zip(coeffs.iter()) + .map(|(set, coeff)| set.msm(coeff, commitments, &powers_of_mu)); + + msms.zip(proof.gamma.powers(sets.len()).into_iter()) + .map(|(msm, power_of_gamma)| msm * &power_of_gamma) + .sum::>() + - Msm::base(proof.w.clone()) * &coeffs[0].z_s + }; + + let mut rhs = Msm::base(proof.w_prime.clone()); + let mut lhs = f + rhs.clone() * &proof.z_prime; + + let loader = >::LoadedScalar::loader(z); + let dummy_point = loader.ec_point_load_one(); + for base in rhs.bases.iter_mut() { + *base = loader.ec_point_select(&dummy_point, base, use_dummy)?; + } + for base in lhs.bases.iter_mut() { + *base = loader.ec_point_select(&dummy_point, base, use_dummy)?; + } + + Ok(KzgAccumulator::new(lhs.evaluate(Some(svk.g)), rhs.evaluate(Some(svk.g)))) } } @@ -99,24 +135,14 @@ where let w = transcript.read_ec_point()?; let z_prime = transcript.squeeze_challenge(); let w_prime = transcript.read_ec_point()?; - Ok(Bdfg21Proof { - mu, - gamma, - w, - z_prime, - w_prime, - }) + Ok(Bdfg21Proof { mu, gamma, w, z_prime, w_prime }) } } fn query_sets(queries: &[Query]) -> Vec> { - let poly_shifts = queries.iter().fold( - Vec::<(usize, Vec, Vec<&T>)>::new(), - |mut poly_shifts, query| { - if let Some(pos) = poly_shifts - .iter() - .position(|(poly, _, _)| *poly == query.poly) - { + let poly_shifts = + queries.iter().fold(Vec::<(usize, Vec, Vec<&T>)>::new(), |mut poly_shifts, query| { + if let Some(pos) = poly_shifts.iter().position(|(poly, _, _)| *poly == query.poly) { let (_, shifts, evals) = &mut poly_shifts[pos]; if !shifts.contains(&query.shift) { shifts.push(query.shift); @@ -126,39 +152,35 @@ fn query_sets(queries: &[Query]) -> Vec>::new(), - |mut sets, (poly, shifts, evals)| { - if let Some(pos) = sets.iter().position(|set| { - BTreeSet::from_iter(set.shifts.iter()) == BTreeSet::from_iter(shifts.iter()) - }) { - let set = &mut sets[pos]; - if !set.polys.contains(&poly) { - set.polys.push(poly); - set.evals.push( - set.shifts - .iter() - .map(|lhs| { - let idx = shifts.iter().position(|rhs| lhs == rhs).unwrap(); - evals[idx].clone() - }) - .collect(), - ); - } - } else { - let set = QuerySet { - shifts, - polys: vec![poly], - evals: vec![evals.into_iter().cloned().collect()], - }; - sets.push(set); + }); + + poly_shifts.into_iter().fold(Vec::>::new(), |mut sets, (poly, shifts, evals)| { + if let Some(pos) = sets.iter().position(|set| { + BTreeSet::from_iter(set.shifts.iter()) == BTreeSet::from_iter(shifts.iter()) + }) { + let set = &mut sets[pos]; + if !set.polys.contains(&poly) { + set.polys.push(poly); + set.evals.push( + set.shifts + .iter() + .map(|lhs| { + let idx = shifts.iter().position(|rhs| lhs == rhs).unwrap(); + evals[idx].clone() + }) + .collect(), + ); } - sets - }, - ) + } else { + let set = QuerySet { + shifts, + polys: vec![poly], + evals: vec![evals.into_iter().cloned().collect()], + }; + sets.push(set); + } + sets + }) } fn query_set_coeffs>( @@ -168,25 +190,17 @@ fn query_set_coeffs>( ) -> Vec> { let loader = z.loader(); - let superset = sets - .iter() - .flat_map(|set| set.shifts.clone()) - .sorted() - .dedup(); + let superset = sets.iter().flat_map(|set| set.shifts.clone()).sorted().dedup(); let size = 2.max( - (sets.iter().map(|set| set.shifts.len()).max().unwrap() - 1) - .next_power_of_two() - .ilog2() as usize + (sets.iter().map(|set| set.shifts.len()).max().unwrap() - 1).next_power_of_two().ilog2() + as usize + 1, ); let powers_of_z = z.powers(size); - let z_prime_minus_z_shift_i = BTreeMap::from_iter(superset.map(|shift| { - ( - shift, - z_prime.clone() - z.clone() * loader.load_const(&shift), - ) - })); + let z_prime_minus_z_shift_i = BTreeMap::from_iter( + superset.map(|shift| (shift, z_prime.clone() - z.clone() * loader.load_const(&shift))), + ); let mut z_s_1 = None; let mut coeffs = sets @@ -318,10 +332,7 @@ where .collect_vec(); let z_s = loader.product( - &shifts - .iter() - .map(|shift| z_prime_minus_z_shift_i.get(shift).unwrap()) - .collect_vec(), + &shifts.iter().map(|shift| z_prime_minus_z_shift_i.get(shift).unwrap()).collect_vec(), ); let z_s_1_over_z_s = z_s_1.clone().map(|z_s_1| Fraction::new(z_s_1, z_s.clone())); @@ -350,13 +361,8 @@ where .iter_mut() .chain(self.commitment_coeff.as_mut()) .for_each(Fraction::evaluate); - let barycentric_weights_sum = loader.sum( - &self - .eval_coeffs - .iter() - .map(Fraction::evaluated) - .collect_vec(), - ); + let barycentric_weights_sum = + loader.sum(&self.eval_coeffs.iter().map(Fraction::evaluated).collect_vec()); self.r_eval_coeff = Some(match self.commitment_coeff.clone() { Some(coeff) => Fraction::new(coeff.evaluated().clone(), barycentric_weights_sum), None => Fraction::one_over(barycentric_weights_sum), diff --git a/src/util/msm.rs b/src/util/msm.rs index a7a3d45d..dd805e50 100644 --- a/src/util/msm.rs +++ b/src/util/msm.rs @@ -12,7 +12,7 @@ use std::{ pub struct Msm> { constant: Option, scalars: Vec, - bases: Vec, + pub(crate) bases: Vec, } impl Default for Msm @@ -21,11 +21,7 @@ where L: Loader, { fn default() -> Self { - Self { - constant: None, - scalars: Vec::new(), - bases: Vec::new(), - } + Self { constant: None, scalars: Vec::new(), bases: Vec::new() } } } @@ -35,19 +31,12 @@ where L: Loader, { pub fn constant(constant: L::LoadedScalar) -> Self { - Msm { - constant: Some(constant), - ..Default::default() - } + Msm { constant: Some(constant), ..Default::default() } } pub fn base(base: L::LoadedEcPoint) -> Self { let one = base.loader().load_one(); - Msm { - scalars: vec![one], - bases: vec![base], - ..Default::default() - } + Msm { scalars: vec![one], bases: vec![base], ..Default::default() } } pub(crate) fn size(&self) -> usize { @@ -64,13 +53,7 @@ where } pub fn evaluate(self, gen: Option) -> L::LoadedEcPoint { - let gen = gen.map(|gen| { - self.bases - .first() - .unwrap() - .loader() - .ec_point_load_const(&gen) - }); + let gen = gen.map(|gen| self.bases.first().unwrap().loader().ec_point_load_const(&gen)); L::LoadedEcPoint::multi_scalar_multiplication( iter::empty() .chain(self.constant.map(|constant| (constant, gen.unwrap()))) diff --git a/src/verifier.rs b/src/verifier.rs index 51e05382..9faec6be 100644 --- a/src/verifier.rs +++ b/src/verifier.rs @@ -34,6 +34,14 @@ where proof: &Self::Proof, ) -> Result, Error>; + fn succinct_verify_or_dummy( + svk: &MOS::SuccinctVerifyingKey, + protocol: &Protocol, + instances: &[Vec], + proof: &Self::Proof, + use_dummy: &L::LoadedScalar, + ) -> Result, Error>; + fn verify( svk: &MOS::SuccinctVerifyingKey, dk: &MOS::DecidingKey, diff --git a/src/verifier/plonk.rs b/src/verifier/plonk.rs index 7d653137..8a5cb7fa 100644 --- a/src/verifier/plonk.rs +++ b/src/verifier/plonk.rs @@ -71,6 +71,47 @@ where Ok(accumulators) } + + fn succinct_verify_or_dummy( + svk: &MOS::SuccinctVerifyingKey, + protocol: &Protocol, + instances: &[Vec], + proof: &Self::Proof, + use_dummy: &L::LoadedScalar, + ) -> Result, Error> { + let common_poly_eval = { + let mut common_poly_eval = CommonPolynomialEvaluation::new( + &protocol.domain, + langranges(protocol, instances), + &proof.z, + ); + + L::LoadedScalar::batch_invert(common_poly_eval.denoms()); + common_poly_eval.evaluate(); + + common_poly_eval + }; + + let mut evaluations = proof.evaluations(protocol, instances, &common_poly_eval)?; + let commitments = proof.commitments(protocol, &common_poly_eval, &mut evaluations)?; + let queries = proof.queries(protocol, evaluations); + + let accumulator = MOS::succinct_verify_or_dummy( + svk, + &commitments, + &proof.z, + &queries, + &proof.pcs, + use_dummy, + )?; + + let accumulators = iter::empty() + .chain(Some(accumulator)) + .chain(proof.old_accumulators.iter().cloned()) + .collect(); + + Ok(accumulators) + } } #[derive(Clone, Debug)] From 916b29fe89cb0dafebdf040777b1384587208ab5 Mon Sep 17 00:00:00 2001 From: Han Date: Mon, 24 Oct 2022 03:39:01 -0700 Subject: [PATCH 04/28] Rollback to pse halo2 and halo2wrong for first release (#5) * feat: move `Accumulator` to `accumulator.rs` * feat: update due to halo2 * feat: upgrade to use branch `feature/generic-instructions` of `halo2wrong` * refactor: rollback to `{halo2,halo2_wrong}` without challenge API and cleanup dependencies * chore: rename statement to instance and auxliary to witness * chore: use `finalize` instead of `code` * feat: add `Code::deployment` and `EvmLoader::deployment_code`; add example `evm-verifier-codegen` * fix: typo * feat: reduce generated evm verifier size; rename to `evm-verifier` and add another example `evm-verifier-with-accumulator` * fix: due to `halo2wrong` * feat: reorganize mods and traits * fix: allow empty `values` in `sum_*` and move them under `ScalarLoader` * ci: use `--all-features` for `cargo test` * fix: use same strategy for aggregation testing * fix: simplify trait `PlonkVerifier` again * fix: move system specified transcript under mod `system` * feat: add `quotient_poly` info in `Protocol` * feat: implement linearization for circom integration * feat: re-export loader's dependency for consumer * refactor: for circom's integration * tmp: pin `revm` to rev * fix: remove parentheses * fix: upgrade for multi-phase halo2 * feat: improve error reporting * chore: rename crate to snake case * feat: add `Domain` as an input of `PolynomialCommitmentScheme::read_proof` * refactor: for further integration * feat: generalize to suppoer also ipa and add builder fns to `system::halo2::Config` * feat: add `KzgDecider` for simple evm verifier * refactor: split `AccumulationScheme` and `AccumulatorEncoding` * refactor: split `PolynomialCommitmentScheme` and `MultiOpenScheme` * fix: not need sealed actually * fix: `chunk_size` should be `LIMBS` when recovering accumulator * feat: add `Expression::DistributePowers` to avoid stack overflow * fix: update and pin foundry * fix: move testing circuits under `system/halo2` * fix: allow accumulate single accumulator * feat: remove all patch and make less depending `halo2wrong` --- .github/workflows/ci.yaml | 2 +- .gitignore | 2 +- Cargo.toml | 69 +- examples/evm-verifier-with-accumulator.rs | 619 ++++++++++++ examples/evm-verifier.rs | 260 +++++ rust-toolchain | 2 +- src/cost.rs | 44 + src/lib.rs | 33 +- src/loader.rs | 195 ++-- src/loader/evm.rs | 68 +- src/loader/evm/accumulation.rs | 98 -- src/loader/evm/code.rs | 115 ++- src/loader/evm/loader.rs | 503 ++++++---- src/loader/evm/test.rs | 18 +- src/loader/evm/test/tui.rs | 2 +- src/loader/evm/transcript.rs | 229 ----- src/loader/evm/util.rs | 92 ++ src/loader/halo2.rs | 33 +- src/loader/halo2/accumulation.rs | 93 -- src/loader/halo2/loader.rs | 715 ++++++------- src/loader/halo2/shim.rs | 403 ++++++++ src/loader/halo2/test.rs | 66 ++ src/loader/halo2/transcript.rs | 363 ------- src/loader/native.rs | 87 +- src/loader/native/accumulation.rs | 111 -- src/loader/native/loader.rs | 61 -- src/pcs.rs | 138 +++ src/pcs/kzg.rs | 45 + src/pcs/kzg/accumulation.rs | 196 ++++ src/pcs/kzg/accumulator.rs | 208 ++++ src/pcs/kzg/decider.rs | 162 +++ src/pcs/kzg/multiopen.rs | 5 + src/pcs/kzg/multiopen/bdfg21.rs | 381 +++++++ src/pcs/kzg/multiopen/gwc19.rs | 167 +++ src/protocol.rs | 54 - src/protocol/halo2/test.rs | 176 ---- src/protocol/halo2/test/circuit/maingate.rs | 385 ------- src/protocol/halo2/test/circuit/plookup.rs | 947 ------------------ src/protocol/halo2/test/kzg.rs | 232 ----- src/protocol/halo2/test/kzg/evm.rs | 168 ---- src/protocol/halo2/test/kzg/halo2.rs | 380 ------- src/protocol/halo2/util.rs | 81 -- src/protocol/halo2/util/evm.rs | 142 --- src/protocol/halo2/util/halo2.rs | 212 ---- src/scheme.rs | 1 - src/scheme/kzg.rs | 35 - src/scheme/kzg/accumulation.rs | 171 ---- src/scheme/kzg/accumulation/plonk.rs | 373 ------- src/scheme/kzg/accumulation/shplonk.rs | 593 ----------- src/scheme/kzg/cost.rs | 29 - src/scheme/kzg/msm.rs | 149 --- src/system.rs | 2 + src/{protocol => system}/halo2.rs | 285 ++++-- src/system/halo2/test.rs | 221 ++++ .../halo2/test/circuit.rs | 1 - src/system/halo2/test/circuit/maingate.rs | 111 ++ .../halo2/test/circuit/standard.rs | 56 +- src/system/halo2/test/kzg.rs | 120 +++ src/system/halo2/test/kzg/evm.rs | 138 +++ src/system/halo2/test/kzg/halo2.rs | 372 +++++++ .../halo2/test/kzg/native.rs | 61 +- src/system/halo2/transcript.rs | 82 ++ src/system/halo2/transcript/evm.rs | 400 ++++++++ src/system/halo2/transcript/halo2.rs | 439 ++++++++ src/util.rs | 39 +- src/util/arithmetic.rs | 140 +-- src/util/hash.rs | 6 + src/util/hash/poseidon.rs | 178 ++++ src/util/msm.rs | 203 ++++ src/util/{expression.rs => protocol.rs} | 197 ++-- src/util/transcript.rs | 16 +- src/verifier.rs | 51 + src/verifier/plonk.rs | 464 +++++++++ 73 files changed, 7063 insertions(+), 6232 deletions(-) create mode 100644 examples/evm-verifier-with-accumulator.rs create mode 100644 examples/evm-verifier.rs create mode 100644 src/cost.rs delete mode 100644 src/loader/evm/accumulation.rs delete mode 100644 src/loader/evm/transcript.rs create mode 100644 src/loader/evm/util.rs delete mode 100644 src/loader/halo2/accumulation.rs create mode 100644 src/loader/halo2/shim.rs create mode 100644 src/loader/halo2/test.rs delete mode 100644 src/loader/halo2/transcript.rs delete mode 100644 src/loader/native/accumulation.rs delete mode 100644 src/loader/native/loader.rs create mode 100644 src/pcs.rs create mode 100644 src/pcs/kzg.rs create mode 100644 src/pcs/kzg/accumulation.rs create mode 100644 src/pcs/kzg/accumulator.rs create mode 100644 src/pcs/kzg/decider.rs create mode 100644 src/pcs/kzg/multiopen.rs create mode 100644 src/pcs/kzg/multiopen/bdfg21.rs create mode 100644 src/pcs/kzg/multiopen/gwc19.rs delete mode 100644 src/protocol.rs delete mode 100644 src/protocol/halo2/test.rs delete mode 100644 src/protocol/halo2/test/circuit/maingate.rs delete mode 100644 src/protocol/halo2/test/circuit/plookup.rs delete mode 100644 src/protocol/halo2/test/kzg.rs delete mode 100644 src/protocol/halo2/test/kzg/evm.rs delete mode 100644 src/protocol/halo2/test/kzg/halo2.rs delete mode 100644 src/protocol/halo2/util.rs delete mode 100644 src/protocol/halo2/util/evm.rs delete mode 100644 src/protocol/halo2/util/halo2.rs delete mode 100644 src/scheme.rs delete mode 100644 src/scheme/kzg.rs delete mode 100644 src/scheme/kzg/accumulation.rs delete mode 100644 src/scheme/kzg/accumulation/plonk.rs delete mode 100644 src/scheme/kzg/accumulation/shplonk.rs delete mode 100644 src/scheme/kzg/cost.rs delete mode 100644 src/scheme/kzg/msm.rs create mode 100644 src/system.rs rename src/{protocol => system}/halo2.rs (76%) create mode 100644 src/system/halo2/test.rs rename src/{protocol => system}/halo2/test/circuit.rs (67%) create mode 100644 src/system/halo2/test/circuit/maingate.rs rename src/{protocol => system}/halo2/test/circuit/standard.rs (69%) create mode 100644 src/system/halo2/test/kzg.rs create mode 100644 src/system/halo2/test/kzg/evm.rs create mode 100644 src/system/halo2/test/kzg/halo2.rs rename src/{protocol => system}/halo2/test/kzg/native.rs (50%) create mode 100644 src/system/halo2/transcript.rs create mode 100644 src/system/halo2/transcript/evm.rs create mode 100644 src/system/halo2/transcript/halo2.rs create mode 100644 src/util/hash.rs create mode 100644 src/util/hash/poseidon.rs create mode 100644 src/util/msm.rs rename src/util/{expression.rs => protocol.rs} (65%) create mode 100644 src/verifier.rs create mode 100644 src/verifier/plonk.rs diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 83bfc0bb..a50958d0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -24,7 +24,7 @@ jobs: cache-on-failure: true - name: Run test - run: cargo test --all --features test -- --nocapture + run: cargo test --all --all-features -- --nocapture lint: diff --git a/.gitignore b/.gitignore index ebb68914..0175c775 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,6 @@ .DS_Store /target -fixture +testdata Cargo.lock diff --git a/Cargo.toml b/Cargo.toml index e6ac1d49..de037c4c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,57 +1,54 @@ [package] -name = "plonk-verifier" +name = "plonk_verifier" version = "0.1.0" edition = "2021" [dependencies] -ff = "0.12.0" -group = "0.12.0" itertools = "0.10.3" lazy_static = "1.4.0" -num-bigint = "0.4" -num-traits = "0.2" +num-bigint = "0.4.3" +num-integer = "0.1.45" +num-traits = "0.2.15" rand = "0.8" rand_chacha = "0.3.1" -halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.2.1", package = "halo2curves" } - -# halo2 -blake2b_simd = { version = "1.0.0", optional = true } -halo2_proofs = { version = "0.2.0", optional = true } -halo2_wrong = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", package = "halo2wrong", optional = true } -halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", package = "ecc", optional = true } -halo2_wrong_maingate = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", package = "maingate", optional = true } -halo2_wrong_transcript = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", package = "transcript", optional = true } -poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon", branch = "padding", optional = true } - -# evm +halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.3.0", package = "halo2curves" } + +# system_halo2 +halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2", tag = "v2022_10_22", optional = true } + +# loader_evm ethereum_types = { package = "ethereum-types", version = "0.13.1", default-features = false, features = ["std"], optional = true } -foundry_evm = { git = "https://github.com/foundry-rs/foundry", package = "foundry-evm", rev = "93ee742d", optional = true } -crossterm = { version = "0.22.1", optional = true } -tui = { version = "0.16.0", default-features = false, features = ["crossterm"], optional = true } sha3 = { version = "0.10.1", optional = true } +# loader_halo2 +halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", tag = "v2022_10_22", package = "ecc", optional = true } +poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon", tag = "v2022_10_22", optional = true } + [dev-dependencies] paste = "1.0.7" +# system_halo2 +halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", tag = "v2022_10_22", package = "ecc" } + +# loader_evm +foundry_evm = { git = "https://github.com/foundry-rs/foundry", package = "foundry-evm", rev = "6b1ee60e" } +crossterm = { version = "0.22.1" } +tui = { version = "0.16.0", default-features = false, features = ["crossterm"] } + [features] -default = ["halo2", "evm"] -test = ["halo2", "evm"] +default = ["loader_evm", "loader_halo2", "system_halo2"] -halo2 = ["dep:blake2b_simd", "dep:halo2_proofs", "dep:halo2_wrong", "dep:halo2_wrong_ecc", "dep:halo2_wrong_maingate", "dep:halo2_wrong_transcript", "dep:poseidon"] -evm = ["dep:foundry_evm", "dep:crossterm", "dep:tui", "dep:ethereum_types", "dep:sha3"] -sanity-check = [] +loader_evm = ["dep:ethereum_types", "dep:sha3"] +loader_halo2 = ["dep:halo2_proofs", "dep:halo2_wrong_ecc", "dep:poseidon"] -[patch.crates-io] -halo2_proofs = { git = "https://github.com/han0110/halo2", branch = "experiment", package = "halo2_proofs" } +system_halo2 = ["dep:halo2_proofs"] -[patch."https://github.com/privacy-scaling-explorations/halo2"] -halo2_proofs = { git = "https://github.com/han0110/halo2", branch = "experiment", package = "halo2_proofs" } +sanity_check = [] -[patch."https://github.com/privacy-scaling-explorations/halo2curves"] -halo2_curves = { git = "https://github.com//privacy-scaling-explorations/halo2curves", tag = "0.2.1", package = "halo2curves" } +[[example]] +name = "evm-verifier" +required-features = ["loader_evm", "system_halo2"] -[patch."https://github.com/privacy-scaling-explorations/halo2wrong"] -halo2_wrong = { git = "https://github.com/han0110/halo2wrong", branch = "feature/range-chip-with-tagged-table", package = "halo2wrong" } -halo2_wrong_ecc = { git = "https://github.com/han0110/halo2wrong", branch = "feature/range-chip-with-tagged-table", package = "ecc" } -halo2_wrong_maingate = { git = "https://github.com/han0110/halo2wrong", branch = "feature/range-chip-with-tagged-table", package = "maingate" } -halo2_wrong_transcript = { git = "https://github.com/han0110/halo2wrong", branch = "feature/range-chip-with-tagged-table", package = "transcript" } +[[example]] +name = "evm-verifier-with-accumulator" +required-features = ["loader_halo2", "loader_evm", "system_halo2"] diff --git a/examples/evm-verifier-with-accumulator.rs b/examples/evm-verifier-with-accumulator.rs new file mode 100644 index 00000000..69def21e --- /dev/null +++ b/examples/evm-verifier-with-accumulator.rs @@ -0,0 +1,619 @@ +use ethereum_types::Address; +use foundry_evm::executor::{fork::MultiFork, Backend, ExecutorBuilder}; +use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; +use halo2_proofs::{ + dev::MockProver, + plonk::{create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey}, + poly::{ + commitment::{Params, ParamsProver}, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::{ProverGWC, VerifierGWC}, + strategy::AccumulatorStrategy, + }, + VerificationStrategy, + }, + transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer}, +}; +use itertools::Itertools; +use plonk_verifier::{ + loader::{ + evm::{encode_calldata, EvmLoader}, + native::NativeLoader, + }, + pcs::kzg::{Gwc19, Kzg, KzgAs, LimbsEncoding}, + system::halo2::{compile, transcript::evm::EvmTranscript, Config}, + verifier::{self, PlonkVerifier}, +}; +use rand::rngs::OsRng; +use std::{io::Cursor, rc::Rc}; + +const LIMBS: usize = 4; +const BITS: usize = 68; + +type Pcs = Kzg; +type As = KzgAs; +type Plonk = verifier::Plonk>; + +mod application { + use halo2_curves::bn256::Fr; + use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, + poly::Rotation, + }; + use rand::RngCore; + + #[derive(Clone, Copy)] + pub struct StandardPlonkConfig { + a: Column, + b: Column, + c: Column, + q_a: Column, + q_b: Column, + q_c: Column, + q_ab: Column, + constant: Column, + #[allow(dead_code)] + instance: Column, + } + + impl StandardPlonkConfig { + fn configure(meta: &mut ConstraintSystem) -> Self { + let [a, b, c] = [(); 3].map(|_| meta.advice_column()); + let [q_a, q_b, q_c, q_ab, constant] = [(); 5].map(|_| meta.fixed_column()); + let instance = meta.instance_column(); + + [a, b, c].map(|column| meta.enable_equality(column)); + + meta.create_gate( + "q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0", + |meta| { + let [a, b, c] = + [a, b, c].map(|column| meta.query_advice(column, Rotation::cur())); + let [q_a, q_b, q_c, q_ab, constant] = [q_a, q_b, q_c, q_ab, constant] + .map(|column| meta.query_fixed(column, Rotation::cur())); + let instance = meta.query_instance(instance, Rotation::cur()); + Some( + q_a * a.clone() + + q_b * b.clone() + + q_c * c + + q_ab * a * b + + constant + + instance, + ) + }, + ); + + StandardPlonkConfig { + a, + b, + c, + q_a, + q_b, + q_c, + q_ab, + constant, + instance, + } + } + } + + #[derive(Clone, Default)] + pub struct StandardPlonk(Fr); + + impl StandardPlonk { + pub fn rand(mut rng: R) -> Self { + Self(Fr::from(rng.next_u32() as u64)) + } + + pub fn num_instance() -> Vec { + vec![1] + } + + pub fn instances(&self) -> Vec> { + vec![vec![self.0]] + } + } + + impl Circuit for StandardPlonk { + type Config = StandardPlonkConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + meta.set_minimum_degree(4); + StandardPlonkConfig::configure(meta) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region( + || "", + |mut region| { + region.assign_advice(|| "", config.a, 0, || Value::known(self.0))?; + region.assign_fixed(|| "", config.q_a, 0, || Value::known(-Fr::one()))?; + + region.assign_advice(|| "", config.a, 1, || Value::known(-Fr::from(5)))?; + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed(|| "", column, 1, || Value::known(Fr::from(idx)))?; + } + + let a = region.assign_advice(|| "", config.a, 2, || Value::known(Fr::one()))?; + a.copy_advice(|| "", &mut region, config.b, 3)?; + a.copy_advice(|| "", &mut region, config.c, 4)?; + + Ok(()) + }, + ) + } + } +} + +mod aggregation { + use super::{As, Plonk, BITS, LIMBS}; + use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; + use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{self, Circuit, ConstraintSystem}, + poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, + }; + use halo2_wrong_ecc::{ + integer::rns::Rns, + maingate::{ + MainGate, MainGateConfig, MainGateInstructions, RangeChip, RangeConfig, + RangeInstructions, RegionCtx, + }, + EccConfig, + }; + use itertools::Itertools; + use plonk_verifier::{ + loader::{self, native::NativeLoader}, + pcs::{ + kzg::{KzgAccumulator, KzgSuccinctVerifyingKey}, + AccumulationScheme, AccumulationSchemeProver, + }, + system, + util::arithmetic::{fe_to_limbs, FieldExt}, + verifier::PlonkVerifier, + Protocol, + }; + use rand::rngs::OsRng; + use std::{iter, rc::Rc}; + + const T: usize = 5; + const RATE: usize = 4; + const R_F: usize = 8; + const R_P: usize = 60; + + type Svk = KzgSuccinctVerifyingKey; + type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip; + type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; + pub type PoseidonTranscript = + system::halo2::transcript::halo2::PoseidonTranscript; + + pub struct Snark { + protocol: Protocol, + instances: Vec>, + proof: Vec, + } + + impl Snark { + pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { + Self { + protocol, + instances, + proof, + } + } + } + + impl From for SnarkWitness { + fn from(snark: Snark) -> Self { + Self { + protocol: snark.protocol, + instances: snark + .instances + .into_iter() + .map(|instances| instances.into_iter().map(Value::known).collect_vec()) + .collect(), + proof: Value::known(snark.proof), + } + } + } + + #[derive(Clone)] + pub struct SnarkWitness { + protocol: Protocol, + instances: Vec>>, + proof: Value>, + } + + impl SnarkWitness { + fn without_witnesses(&self) -> Self { + SnarkWitness { + protocol: self.protocol.clone(), + instances: self + .instances + .iter() + .map(|instances| vec![Value::unknown(); instances.len()]) + .collect(), + proof: Value::unknown(), + } + } + + fn proof(&self) -> Value<&[u8]> { + self.proof.as_ref().map(Vec::as_slice) + } + } + + pub fn aggregate<'a>( + svk: &Svk, + loader: &Rc>, + snarks: &[SnarkWitness], + as_proof: Value<&'_ [u8]>, + ) -> KzgAccumulator>> { + let assign_instances = |instances: &[Vec>]| { + instances + .iter() + .map(|instances| { + instances + .iter() + .map(|instance| loader.assign_scalar(*instance)) + .collect_vec() + }) + .collect_vec() + }; + + let accumulators = snarks + .iter() + .flat_map(|snark| { + let instances = assign_instances(&snark.instances); + let mut transcript = + PoseidonTranscript::, _>::new(loader, snark.proof()); + let proof = + Plonk::read_proof(svk, &snark.protocol, &instances, &mut transcript).unwrap(); + Plonk::succinct_verify(svk, &snark.protocol, &instances, &proof).unwrap() + }) + .collect_vec(); + + let acccumulator = { + let mut transcript = PoseidonTranscript::, _>::new(loader, as_proof); + let proof = + As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); + As::verify(&Default::default(), &accumulators, &proof).unwrap() + }; + + acccumulator + } + + #[derive(Clone)] + pub struct AggregationConfig { + main_gate_config: MainGateConfig, + range_config: RangeConfig, + } + + impl AggregationConfig { + pub fn configure( + meta: &mut ConstraintSystem, + composition_bits: Vec, + overflow_bits: Vec, + ) -> Self { + let main_gate_config = MainGate::::configure(meta); + let range_config = + RangeChip::::configure(meta, &main_gate_config, composition_bits, overflow_bits); + AggregationConfig { + main_gate_config, + range_config, + } + } + + pub fn main_gate(&self) -> MainGate { + MainGate::new(self.main_gate_config.clone()) + } + + pub fn range_chip(&self) -> RangeChip { + RangeChip::new(self.range_config.clone()) + } + + pub fn ecc_chip(&self) -> BaseFieldEccChip { + BaseFieldEccChip::new(EccConfig::new( + self.range_config.clone(), + self.main_gate_config.clone(), + )) + } + } + + #[derive(Clone)] + pub struct AggregationCircuit { + svk: Svk, + snarks: Vec, + instances: Vec, + as_proof: Value>, + } + + impl AggregationCircuit { + pub fn new(params: &ParamsKZG, snarks: impl IntoIterator) -> Self { + let svk = params.get_g()[0].into(); + let snarks = snarks.into_iter().collect_vec(); + + let accumulators = snarks + .iter() + .flat_map(|snark| { + let mut transcript = + PoseidonTranscript::::new(snark.proof.as_slice()); + let proof = + Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript) + .unwrap(); + Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof).unwrap() + }) + .collect_vec(); + + let (accumulator, as_proof) = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + let accumulator = + As::create_proof(&Default::default(), &accumulators, &mut transcript, OsRng) + .unwrap(); + (accumulator, transcript.finalize()) + }; + + let KzgAccumulator { lhs, rhs } = accumulator; + let instances = [lhs.x, lhs.y, rhs.x, rhs.y] + .map(fe_to_limbs::<_, _, LIMBS, BITS>) + .concat(); + + Self { + svk, + snarks: snarks.into_iter().map_into().collect(), + instances, + as_proof: Value::known(as_proof), + } + } + + pub fn accumulator_indices() -> Vec<(usize, usize)> { + (0..4 * LIMBS).map(|idx| (0, idx)).collect() + } + + pub fn num_instance() -> Vec { + vec![4 * LIMBS] + } + + pub fn instances(&self) -> Vec> { + vec![self.instances.clone()] + } + + pub fn as_proof(&self) -> Value<&[u8]> { + self.as_proof.as_ref().map(Vec::as_slice) + } + } + + impl Circuit for AggregationCircuit { + type Config = AggregationConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self { + svk: self.svk, + snarks: self + .snarks + .iter() + .map(SnarkWitness::without_witnesses) + .collect(), + instances: Vec::new(), + as_proof: Value::unknown(), + } + } + + fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { + AggregationConfig::configure( + meta, + vec![BITS / LIMBS], + Rns::::construct().overflow_lengths(), + ) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), plonk::Error> { + let main_gate = config.main_gate(); + let range_chip = config.range_chip(); + + range_chip.load_table(&mut layouter)?; + + let (lhs, rhs) = layouter.assign_region( + || "", + |region| { + let ctx = RegionCtx::new(region, 0); + + let ecc_chip = config.ecc_chip(); + let loader = Halo2Loader::new(ecc_chip, ctx); + let KzgAccumulator { lhs, rhs } = + aggregate(&self.svk, &loader, &self.snarks, self.as_proof()); + + Ok((lhs.assigned(), rhs.assigned())) + }, + )?; + + for (limb, row) in iter::empty() + .chain(lhs.x().limbs()) + .chain(lhs.y().limbs()) + .chain(rhs.x().limbs()) + .chain(rhs.y().limbs()) + .zip(0..) + { + main_gate.expose_public(layouter.namespace(|| ""), limb.into(), row)?; + } + + Ok(()) + } + } +} + +fn gen_srs(k: u32) -> ParamsKZG { + ParamsKZG::::setup(k, OsRng) +} + +fn gen_pk>(params: &ParamsKZG, circuit: &C) -> ProvingKey { + let vk = keygen_vk(params, circuit).unwrap(); + keygen_pk(params, vk, circuit).unwrap() +} + +fn gen_proof< + C: Circuit, + E: EncodedChallenge, + TR: TranscriptReadBuffer>, G1Affine, E>, + TW: TranscriptWriterBuffer, G1Affine, E>, +>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: C, + instances: Vec>, +) -> Vec { + MockProver::run(params.k(), &circuit, instances.clone()) + .unwrap() + .assert_satisfied(); + + let instances = instances + .iter() + .map(|instances| instances.as_slice()) + .collect_vec(); + let proof = { + let mut transcript = TW::init(Vec::new()); + create_proof::, ProverGWC<_>, _, _, TW, _>( + params, + pk, + &[circuit], + &[instances.as_slice()], + OsRng, + &mut transcript, + ) + .unwrap(); + transcript.finalize() + }; + + let accept = { + let mut transcript = TR::init(Cursor::new(proof.clone())); + VerificationStrategy::<_, VerifierGWC<_>>::finalize( + verify_proof::<_, VerifierGWC<_>, _, TR, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript, + ) + .unwrap(), + ) + }; + assert!(accept); + + proof +} + +fn gen_application_snark(params: &ParamsKZG) -> aggregation::Snark { + let circuit = application::StandardPlonk::rand(OsRng); + + let pk = gen_pk(params, &circuit); + let protocol = compile( + params, + pk.get_vk(), + Config::kzg().with_num_instance(application::StandardPlonk::num_instance()), + ); + + let proof = gen_proof::< + _, + _, + aggregation::PoseidonTranscript, + aggregation::PoseidonTranscript, + >(params, &pk, circuit.clone(), circuit.instances()); + aggregation::Snark::new(protocol, circuit.instances(), proof) +} + +fn gen_aggregation_evm_verifier( + params: &ParamsKZG, + vk: &VerifyingKey, + num_instance: Vec, + accumulator_indices: Vec<(usize, usize)>, +) -> Vec { + let svk = params.get_g()[0].into(); + let dk = (params.g2(), params.s_g2()).into(); + let protocol = compile( + params, + vk, + Config::kzg() + .with_num_instance(num_instance.clone()) + .with_accumulator_indices(accumulator_indices), + ); + + let loader = EvmLoader::new::(); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); + + let instances = transcript.load_instances(num_instance); + let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); + Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); + + loader.deployment_code() +} + +fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { + let calldata = encode_calldata(&instances, &proof); + let success = { + let mut evm = ExecutorBuilder::default() + .with_gas_limit(u64::MAX.into()) + .build(Backend::new(MultiFork::new().0, None)); + + let caller = Address::from_low_u64_be(0xfe); + let verifier = evm + .deploy(caller, deployment_code.into(), 0.into(), None) + .unwrap() + .address; + let result = evm + .call_raw(caller, verifier, calldata.into(), 0.into()) + .unwrap(); + + dbg!(result.gas_used); + + !result.reverted + }; + assert!(success); +} + +fn main() { + let params = gen_srs(22); + let params_app = { + let mut params = params.clone(); + params.downsize(8); + params + }; + + let snarks = [(); 3].map(|_| gen_application_snark(¶ms_app)); + let agg_circuit = aggregation::AggregationCircuit::new(¶ms, snarks); + let pk = gen_pk(¶ms, &agg_circuit); + let deployment_code = gen_aggregation_evm_verifier( + ¶ms, + pk.get_vk(), + aggregation::AggregationCircuit::num_instance(), + aggregation::AggregationCircuit::accumulator_indices(), + ); + + let proof = gen_proof::<_, _, EvmTranscript, EvmTranscript>( + ¶ms, + &pk, + agg_circuit.clone(), + agg_circuit.instances(), + ); + evm_verify(deployment_code, agg_circuit.instances(), proof); +} diff --git a/examples/evm-verifier.rs b/examples/evm-verifier.rs new file mode 100644 index 00000000..b51a9a30 --- /dev/null +++ b/examples/evm-verifier.rs @@ -0,0 +1,260 @@ +use ethereum_types::Address; +use foundry_evm::executor::{fork::MultiFork, Backend, ExecutorBuilder}; +use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; +use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + dev::MockProver, + plonk::{ + create_proof, keygen_pk, keygen_vk, verify_proof, Advice, Circuit, Column, + ConstraintSystem, Error, Fixed, Instance, ProvingKey, VerifyingKey, + }, + poly::{ + commitment::{Params, ParamsProver}, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::{ProverGWC, VerifierGWC}, + strategy::AccumulatorStrategy, + }, + Rotation, VerificationStrategy, + }, + transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, +}; +use itertools::Itertools; +use plonk_verifier::{ + loader::evm::{encode_calldata, EvmLoader}, + pcs::kzg::{Gwc19, Kzg}, + system::halo2::{compile, transcript::evm::EvmTranscript, Config}, + verifier::{self, PlonkVerifier}, +}; +use rand::{rngs::OsRng, RngCore}; +use std::rc::Rc; + +type Plonk = verifier::Plonk>; + +#[derive(Clone, Copy)] +struct StandardPlonkConfig { + a: Column, + b: Column, + c: Column, + q_a: Column, + q_b: Column, + q_c: Column, + q_ab: Column, + constant: Column, + #[allow(dead_code)] + instance: Column, +} + +impl StandardPlonkConfig { + fn configure(meta: &mut ConstraintSystem) -> Self { + let [a, b, c] = [(); 3].map(|_| meta.advice_column()); + let [q_a, q_b, q_c, q_ab, constant] = [(); 5].map(|_| meta.fixed_column()); + let instance = meta.instance_column(); + + [a, b, c].map(|column| meta.enable_equality(column)); + + meta.create_gate( + "q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0", + |meta| { + let [a, b, c] = [a, b, c].map(|column| meta.query_advice(column, Rotation::cur())); + let [q_a, q_b, q_c, q_ab, constant] = [q_a, q_b, q_c, q_ab, constant] + .map(|column| meta.query_fixed(column, Rotation::cur())); + let instance = meta.query_instance(instance, Rotation::cur()); + Some( + q_a * a.clone() + + q_b * b.clone() + + q_c * c + + q_ab * a * b + + constant + + instance, + ) + }, + ); + + StandardPlonkConfig { + a, + b, + c, + q_a, + q_b, + q_c, + q_ab, + constant, + instance, + } + } +} + +#[derive(Clone, Default)] +struct StandardPlonk(Fr); + +impl StandardPlonk { + fn rand(mut rng: R) -> Self { + Self(Fr::from(rng.next_u32() as u64)) + } + + fn num_instance() -> Vec { + vec![1] + } + + fn instances(&self) -> Vec> { + vec![vec![self.0]] + } +} + +impl Circuit for StandardPlonk { + type Config = StandardPlonkConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + meta.set_minimum_degree(4); + StandardPlonkConfig::configure(meta) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region( + || "", + |mut region| { + region.assign_advice(|| "", config.a, 0, || Value::known(self.0))?; + region.assign_fixed(|| "", config.q_a, 0, || Value::known(-Fr::one()))?; + + region.assign_advice(|| "", config.a, 1, || Value::known(-Fr::from(5)))?; + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed(|| "", column, 1, || Value::known(Fr::from(idx)))?; + } + + let a = region.assign_advice(|| "", config.a, 2, || Value::known(Fr::one()))?; + a.copy_advice(|| "", &mut region, config.b, 3)?; + a.copy_advice(|| "", &mut region, config.c, 4)?; + + Ok(()) + }, + ) + } +} + +fn gen_srs(k: u32) -> ParamsKZG { + ParamsKZG::::setup(k, OsRng) +} + +fn gen_pk>(params: &ParamsKZG, circuit: &C) -> ProvingKey { + let vk = keygen_vk(params, circuit).unwrap(); + keygen_pk(params, vk, circuit).unwrap() +} + +fn gen_proof>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: C, + instances: Vec>, +) -> Vec { + MockProver::run(params.k(), &circuit, instances.clone()) + .unwrap() + .assert_satisfied(); + + let instances = instances + .iter() + .map(|instances| instances.as_slice()) + .collect_vec(); + let proof = { + let mut transcript = TranscriptWriterBuffer::<_, G1Affine, _>::init(Vec::new()); + create_proof::, ProverGWC<_>, _, _, EvmTranscript<_, _, _, _>, _>( + params, + pk, + &[circuit], + &[instances.as_slice()], + OsRng, + &mut transcript, + ) + .unwrap(); + transcript.finalize() + }; + + let accept = { + let mut transcript = TranscriptReadBuffer::<_, G1Affine, _>::init(proof.as_slice()); + VerificationStrategy::<_, VerifierGWC<_>>::finalize( + verify_proof::<_, VerifierGWC<_>, _, EvmTranscript<_, _, _, _>, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript, + ) + .unwrap(), + ) + }; + assert!(accept); + + proof +} + +fn gen_evm_verifier( + params: &ParamsKZG, + vk: &VerifyingKey, + num_instance: Vec, +) -> Vec { + let svk = params.get_g()[0].into(); + let dk = (params.g2(), params.s_g2()).into(); + let protocol = compile( + params, + vk, + Config::kzg().with_num_instance(num_instance.clone()), + ); + + let loader = EvmLoader::new::(); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); + + let instances = transcript.load_instances(num_instance); + let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); + Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); + + loader.deployment_code() +} + +fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { + let calldata = encode_calldata(&instances, &proof); + let success = { + let mut evm = ExecutorBuilder::default() + .with_gas_limit(u64::MAX.into()) + .build(Backend::new(MultiFork::new().0, None)); + + let caller = Address::from_low_u64_be(0xfe); + let verifier = evm + .deploy(caller, deployment_code.into(), 0.into(), None) + .unwrap() + .address; + let result = evm + .call_raw(caller, verifier, calldata.into(), 0.into()) + .unwrap(); + + dbg!(result.gas_used); + + !result.reverted + }; + assert!(success); +} + +fn main() { + let params = gen_srs(8); + + let circuit = StandardPlonk::rand(OsRng); + let pk = gen_pk(¶ms, &circuit); + let deployment_code = gen_evm_verifier(¶ms, pk.get_vk(), StandardPlonk::num_instance()); + + let proof = gen_proof(¶ms, &pk, circuit.clone(), circuit.instances()); + evm_verify(deployment_code, circuit.instances(), proof); +} diff --git a/rust-toolchain b/rust-toolchain index db84486f..7cc6ef41 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -nightly-2022-06-01 \ No newline at end of file +1.63.0 \ No newline at end of file diff --git a/src/cost.rs b/src/cost.rs new file mode 100644 index 00000000..b085aed8 --- /dev/null +++ b/src/cost.rs @@ -0,0 +1,44 @@ +use std::ops::Add; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Cost { + pub num_instance: usize, + pub num_commitment: usize, + pub num_evaluation: usize, + pub num_msm: usize, +} + +impl Cost { + pub fn new( + num_instance: usize, + num_commitment: usize, + num_evaluation: usize, + num_msm: usize, + ) -> Self { + Self { + num_instance, + num_commitment, + num_evaluation, + num_msm, + } + } +} + +impl Add for Cost { + type Output = Cost; + + fn add(self, rhs: Cost) -> Self::Output { + Cost::new( + self.num_instance + rhs.num_instance, + self.num_commitment + rhs.num_commitment, + self.num_evaluation + rhs.num_evaluation, + self.num_msm + rhs.num_msm, + ) + } +} + +pub trait CostEstimation { + type Input; + + fn estimate_cost(input: &Self::Input) -> Cost; +} diff --git a/src/lib.rs b/src/lib.rs index b193cc5c..4e8da5fb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,19 +1,38 @@ -#![feature(int_log)] -#![feature(int_roundings)] -#![feature(assert_matches)] #![allow(clippy::type_complexity)] #![allow(clippy::too_many_arguments)] #![allow(clippy::upper_case_acronyms)] +pub mod cost; pub mod loader; -pub mod protocol; -pub mod scheme; +pub mod pcs; +pub mod system; pub mod util; +pub mod verifier; #[derive(Clone, Debug)] pub enum Error { InvalidInstances, - MissingQuery(util::Query), - MissingChallenge(usize), + InvalidLinearization, + InvalidQuery(util::protocol::Query), + InvalidChallenge(usize), + AssertionFailure(String), Transcript(std::io::ErrorKind, String), } + +#[derive(Clone, Debug)] +pub struct Protocol { + // Common description + pub domain: util::arithmetic::Domain, + pub preprocessed: Vec, + pub num_instance: Vec, + pub num_witness: Vec, + pub num_challenge: Vec, + pub evaluations: Vec, + pub queries: Vec, + pub quotient: util::protocol::QuotientPolynomial, + // Minor customization + pub transcript_initial_state: Option, + pub instance_committing_key: Option>, + pub linearization: Option, + pub accumulator_indices: Vec>, +} diff --git a/src/loader.rs b/src/loader.rs index ebca0ec9..8c39bae0 100644 --- a/src/loader.rs +++ b/src/loader.rs @@ -1,15 +1,21 @@ -use crate::util::{Curve, FieldOps, GroupOps, Itertools, PrimeField}; +use crate::{ + util::{ + arithmetic::{CurveAffine, FieldOps, PrimeField}, + Itertools, + }, + Error, +}; use std::{fmt::Debug, iter}; pub mod native; -#[cfg(feature = "evm")] +#[cfg(feature = "loader_evm")] pub mod evm; -#[cfg(feature = "halo2")] +#[cfg(feature = "loader_halo2")] pub mod halo2; -pub trait LoadedEcPoint: Clone + Debug + GroupOps + PartialEq { +pub trait LoadedEcPoint: Clone + Debug + PartialEq { type Loader: Loader; fn loader(&self) -> &Self::Loader; @@ -24,67 +30,11 @@ pub trait LoadedEcPoint: Clone + Debug + GroupOps + PartialEq { ) -> Self; } -pub trait LoadedScalar: Clone + Debug + FieldOps { +pub trait LoadedScalar: Clone + Debug + PartialEq + FieldOps { type Loader: ScalarLoader; fn loader(&self) -> &Self::Loader; - fn sum_with_coeff_and_constant(values: &[(F, Self)], constant: &F) -> Self { - assert!(!values.is_empty()); - - let loader = values.first().unwrap().1.loader(); - iter::empty() - .chain(if *constant == F::zero() { - None - } else { - Some(loader.load_const(constant)) - }) - .chain( - values - .iter() - .map(|(coeff, value)| loader.load_const(coeff) * value), - ) - .reduce(|acc, term| acc + term) - .unwrap() - } - - fn sum_products_with_coeff_and_constant(values: &[(F, Self, Self)], constant: &F) -> Self { - assert!(!values.is_empty()); - - let loader = values.first().unwrap().1.loader(); - iter::empty() - .chain(if *constant == F::zero() { - None - } else { - Some(loader.load_const(constant)) - }) - .chain( - values - .iter() - .map(|(coeff, lhs, rhs)| loader.load_const(coeff) * lhs * rhs), - ) - .reduce(|acc, term| acc + term) - .unwrap() - } - - fn sum_with_coeff(values: &[(F, Self)]) -> Self { - Self::sum_with_coeff_and_constant(values, &F::zero()) - } - - fn sum_with_const(values: &[Self], constant: &F) -> Self { - Self::sum_with_coeff_and_constant( - &values - .iter() - .map(|value| (F::one(), value.clone())) - .collect_vec(), - constant, - ) - } - - fn sum(values: &[Self]) -> Self { - Self::sum_with_const(values, &F::zero()) - } - fn square(&self) -> Self { self.clone() * self } @@ -133,7 +83,7 @@ pub trait LoadedScalar: Clone + Debug + FieldOps { } } -pub trait EcPointLoader { +pub trait EcPointLoader { type LoadedEcPoint: LoadedEcPoint; fn ec_point_load_const(&self, value: &C) -> Self::LoadedEcPoint; @@ -145,6 +95,13 @@ pub trait EcPointLoader { fn ec_point_load_one(&self) -> Self::LoadedEcPoint { self.ec_point_load_const(&C::generator()) } + + fn ec_point_assert_eq( + &self, + annotation: &str, + lhs: &Self::LoadedEcPoint, + rhs: &Self::LoadedEcPoint, + ) -> Result<(), Error>; } pub trait ScalarLoader { @@ -159,9 +116,121 @@ pub trait ScalarLoader { fn load_one(&self) -> Self::LoadedScalar { self.load_const(&F::one()) } + + fn assert_eq( + &self, + annotation: &str, + lhs: &Self::LoadedScalar, + rhs: &Self::LoadedScalar, + ) -> Result<(), Error>; + + fn sum_with_coeff_and_const( + &self, + values: &[(F, &Self::LoadedScalar)], + constant: F, + ) -> Self::LoadedScalar { + if values.is_empty() { + return self.load_const(&constant); + } + + let loader = values.first().unwrap().1.loader(); + iter::empty() + .chain(if constant == F::zero() { + None + } else { + Some(loader.load_const(&constant)) + }) + .chain(values.iter().map(|&(coeff, value)| { + if coeff == F::one() { + value.clone() + } else { + loader.load_const(&coeff) * value + } + })) + .reduce(|acc, term| acc + term) + .unwrap() + } + + fn sum_products_with_coeff_and_const( + &self, + values: &[(F, &Self::LoadedScalar, &Self::LoadedScalar)], + constant: F, + ) -> Self::LoadedScalar { + if values.is_empty() { + return self.load_const(&constant); + } + + let loader = values.first().unwrap().1.loader(); + iter::empty() + .chain(if constant == F::zero() { + None + } else { + Some(loader.load_const(&constant)) + }) + .chain(values.iter().map(|&(coeff, lhs, rhs)| { + if coeff == F::one() { + lhs.clone() * rhs + } else { + loader.load_const(&coeff) * lhs * rhs + } + })) + .reduce(|acc, term| acc + term) + .unwrap() + } + + fn sum_with_coeff(&self, values: &[(F, &Self::LoadedScalar)]) -> Self::LoadedScalar { + self.sum_with_coeff_and_const(values, F::zero()) + } + + fn sum_with_const(&self, values: &[&Self::LoadedScalar], constant: F) -> Self::LoadedScalar { + self.sum_with_coeff_and_const( + &values.iter().map(|&value| (F::one(), value)).collect_vec(), + constant, + ) + } + + fn sum(&self, values: &[&Self::LoadedScalar]) -> Self::LoadedScalar { + self.sum_with_const(values, F::zero()) + } + + fn sum_products_with_coeff( + &self, + values: &[(F, &Self::LoadedScalar, &Self::LoadedScalar)], + ) -> Self::LoadedScalar { + self.sum_products_with_coeff_and_const(values, F::zero()) + } + + fn sum_products_with_const( + &self, + values: &[(&Self::LoadedScalar, &Self::LoadedScalar)], + constant: F, + ) -> Self::LoadedScalar { + self.sum_products_with_coeff_and_const( + &values + .iter() + .map(|&(lhs, rhs)| (F::one(), lhs, rhs)) + .collect_vec(), + constant, + ) + } + + fn sum_products( + &self, + values: &[(&Self::LoadedScalar, &Self::LoadedScalar)], + ) -> Self::LoadedScalar { + self.sum_products_with_const(values, F::zero()) + } + + fn product(&self, values: &[&Self::LoadedScalar]) -> Self::LoadedScalar { + values + .iter() + .fold(self.load_one(), |acc, value| acc * *value) + } } -pub trait Loader: EcPointLoader + ScalarLoader + Clone { +pub trait Loader: + EcPointLoader + ScalarLoader + Clone + Debug +{ fn start_cost_metering(&self, _: &str) {} fn end_cost_metering(&self) {} diff --git a/src/loader/evm.rs b/src/loader/evm.rs index e0754532..7a07670c 100644 --- a/src/loader/evm.rs +++ b/src/loader/evm.rs @@ -1,70 +1,14 @@ -use crate::{ - scheme::kzg::Cost, - util::{Itertools, PrimeField}, -}; -use ethereum_types::U256; -use std::iter; - -mod accumulation; mod code; -mod loader; -mod transcript; +pub(crate) mod loader; +mod util; #[cfg(test)] mod test; -pub use loader::EvmLoader; -pub use transcript::EvmTranscript; +pub use loader::{EcPoint, EvmLoader, Scalar}; +pub use util::{encode_calldata, estimate_gas, fe_to_u256, modulus, u256_to_fe, MemoryChunk}; + +pub use ethereum_types::U256; #[cfg(test)] pub use test::execute; - -// Assert F::Repr in little-endian -pub fn field_to_u256(f: &F) -> U256 -where - F: PrimeField, -{ - U256::from_little_endian(f.to_repr().as_ref()) -} - -pub fn u256_to_field(value: U256) -> F -where - F: PrimeField, -{ - let value = value % (field_to_u256(&-F::one()) + 1u64); - let mut repr = F::Repr::default(); - value.to_little_endian(repr.as_mut()); - F::from_repr(repr).unwrap() -} - -pub fn modulus() -> U256 -where - F: PrimeField, -{ - U256::from_little_endian((-F::one()).to_repr().as_ref()) + 1 -} - -pub fn encode_calldata(instances: Vec>, proof: Vec) -> Vec -where - F: PrimeField, -{ - iter::empty() - .chain( - instances - .into_iter() - .flatten() - .flat_map(|value| value.to_repr().as_ref().iter().rev().cloned().collect_vec()), - ) - .chain(proof) - .collect() -} - -pub fn estimate_gas(cost: Cost) -> usize { - let proof_size = cost.num_commitment * 64 + (cost.num_evaluation + cost.num_statement) * 32; - - let intrinsic_cost = 21000; - let calldata_cost = (proof_size as f64 * 15.25).ceil() as usize; - let ec_operation_cost = 113100 + (cost.num_msm - 2) * 6350; - - intrinsic_cost + calldata_cost + ec_operation_cost -} diff --git a/src/loader/evm/accumulation.rs b/src/loader/evm/accumulation.rs deleted file mode 100644 index cfe4be6f..00000000 --- a/src/loader/evm/accumulation.rs +++ /dev/null @@ -1,98 +0,0 @@ -use crate::{ - loader::evm::loader::{EvmLoader, Scalar}, - protocol::Protocol, - scheme::kzg::{AccumulationStrategy, Accumulator, SameCurveAccumulation, MSM}, - util::{Curve, Itertools, PrimeCurveAffine, PrimeField, Transcript, UncompressedEncoding}, - Error, -}; -use ethereum_types::U256; -use halo2_curves::{ - bn256::{G1Affine, G2Affine, G1}, - CurveAffine, -}; -use std::{ops::Neg, rc::Rc}; - -impl SameCurveAccumulation, LIMBS, BITS> { - pub fn code(self, g1: G1Affine, g2: G2Affine, s_g2: G2Affine) -> Vec { - let (lhs, rhs) = self.accumulator.unwrap().evaluate(g1.to_curve()); - let loader = lhs.loader(); - - let [g2, minus_s_g2] = [g2, s_g2.neg()].map(|ec_point| { - let coordinates = ec_point.coordinates().unwrap(); - let x = coordinates.x().to_repr(); - let y = coordinates.y().to_repr(); - ( - U256::from_little_endian(&x.as_ref()[32..]), - U256::from_little_endian(&x.as_ref()[..32]), - U256::from_little_endian(&y.as_ref()[32..]), - U256::from_little_endian(&y.as_ref()[..32]), - ) - }); - loader.pairing(&lhs, g2, &rhs, minus_s_g2); - - loader.code() - } -} - -impl - AccumulationStrategy, T, P> - for SameCurveAccumulation, LIMBS, BITS> -where - C::Scalar: PrimeField, - C: UncompressedEncoding, - T: Transcript>, -{ - type Output = (); - - fn extract_accumulator( - &self, - protocol: &Protocol, - loader: &Rc, - transcript: &mut T, - statements: &[Vec], - ) -> Option>> { - let accumulator_indices = protocol.accumulator_indices.as_ref()?; - - let num_statements = statements - .iter() - .map(|statements| statements.len()) - .collect_vec(); - - let challenges = transcript.squeeze_n_challenges(accumulator_indices.len()); - let accumulators = accumulator_indices - .iter() - .map(|indices| { - assert_eq!(indices.len(), 4 * LIMBS); - assert!(indices - .iter() - .enumerate() - .all(|(idx, index)| indices[0] == (index.0, index.1 - idx))); - let offset = - (num_statements[..indices[0].0].iter().sum::() + indices[0].1) * 0x20; - let lhs = loader.calldataload_ec_point_from_limbs::(offset); - let rhs = loader.calldataload_ec_point_from_limbs::(offset + 0x100); - Accumulator::new(MSM::base(lhs), MSM::base(rhs)) - }) - .collect_vec(); - - Some(Accumulator::random_linear_combine( - challenges.into_iter().zip(accumulators), - )) - } - - fn process( - &mut self, - _: &Rc, - transcript: &mut T, - _: P, - accumulator: Accumulator>, - ) -> Result { - self.accumulator = Some(match self.accumulator.take() { - Some(curr_accumulator) => { - accumulator + curr_accumulator * &transcript.squeeze_challenge() - } - None => accumulator, - }); - Ok(()) - } -} diff --git a/src/loader/evm/code.rs b/src/loader/evm/code.rs index 38069401..80dd5c71 100644 --- a/src/loader/evm/code.rs +++ b/src/loader/evm/code.rs @@ -1,6 +1,6 @@ use crate::util::Itertools; use ethereum_types::U256; -use foundry_evm::{revm::opcode::*, HashMap}; +use std::{collections::HashMap, iter}; pub enum Precompiled { BigModExp = 0x05, @@ -36,6 +36,37 @@ impl Code { code } + pub fn deployment(code: Vec) -> Vec { + let code_len = code.len(); + assert_ne!(code_len, 0); + + iter::empty() + .chain([ + PUSH1 + 1, + (code_len >> 8) as u8, + (code_len & 0xff) as u8, + PUSH1, + 14, + PUSH1, + 0, + CODECOPY, + ]) + .chain([ + PUSH1 + 1, + (code_len >> 8) as u8, + (code_len & 0xff) as u8, + PUSH1, + 0, + RETURN, + ]) + .chain(code) + .collect() + } + + pub fn stack_len(&self) -> usize { + self.stack_len + } + pub fn len(&self) -> usize { self.code.len() } @@ -180,3 +211,85 @@ impl_opcodes!( revert -> (REVERT, -2) selfdestruct -> (SELFDESTRUCT, -1) ); + +const STOP: u8 = 0x00; +const ADD: u8 = 0x01; +const MUL: u8 = 0x02; +const SUB: u8 = 0x03; +const DIV: u8 = 0x04; +const SDIV: u8 = 0x05; +const MOD: u8 = 0x06; +const SMOD: u8 = 0x07; +const ADDMOD: u8 = 0x08; +const MULMOD: u8 = 0x09; +const EXP: u8 = 0x0A; +const SIGNEXTEND: u8 = 0x0B; +const LT: u8 = 0x10; +const GT: u8 = 0x11; +const SLT: u8 = 0x12; +const SGT: u8 = 0x13; +const EQ: u8 = 0x14; +const ISZERO: u8 = 0x15; +const AND: u8 = 0x16; +const OR: u8 = 0x17; +const XOR: u8 = 0x18; +const NOT: u8 = 0x19; +const BYTE: u8 = 0x1A; +const SHL: u8 = 0x1B; +const SHR: u8 = 0x1C; +const SAR: u8 = 0x1D; +const SHA3: u8 = 0x20; +const ADDRESS: u8 = 0x30; +const BALANCE: u8 = 0x31; +const ORIGIN: u8 = 0x32; +const CALLER: u8 = 0x33; +const CALLVALUE: u8 = 0x34; +const CALLDATALOAD: u8 = 0x35; +const CALLDATASIZE: u8 = 0x36; +const CALLDATACOPY: u8 = 0x37; +const CODESIZE: u8 = 0x38; +const CODECOPY: u8 = 0x39; +const GASPRICE: u8 = 0x3A; +const EXTCODESIZE: u8 = 0x3B; +const EXTCODECOPY: u8 = 0x3C; +const RETURNDATASIZE: u8 = 0x3D; +const RETURNDATACOPY: u8 = 0x3E; +const EXTCODEHASH: u8 = 0x3F; +const BLOCKHASH: u8 = 0x40; +const COINBASE: u8 = 0x41; +const TIMESTAMP: u8 = 0x42; +const NUMBER: u8 = 0x43; +const DIFFICULTY: u8 = 0x44; +const GASLIMIT: u8 = 0x45; +const CHAINID: u8 = 0x46; +const SELFBALANCE: u8 = 0x47; +const BASEFEE: u8 = 0x48; +const POP: u8 = 0x50; +const MLOAD: u8 = 0x51; +const MSTORE: u8 = 0x52; +const MSTORE8: u8 = 0x53; +const SLOAD: u8 = 0x54; +const SSTORE: u8 = 0x55; +const JUMP: u8 = 0x56; +const JUMPI: u8 = 0x57; +const PC: u8 = 0x58; +const MSIZE: u8 = 0x59; +const GAS: u8 = 0x5A; +const JUMPDEST: u8 = 0x5B; +const PUSH1: u8 = 0x60; +const DUP1: u8 = 0x80; +const SWAP1: u8 = 0x90; +const LOG0: u8 = 0xA0; +const LOG1: u8 = 0xA1; +const LOG2: u8 = 0xA2; +const LOG3: u8 = 0xA3; +const LOG4: u8 = 0xA4; +const CREATE: u8 = 0xF0; +const CALL: u8 = 0xF1; +const CALLCODE: u8 = 0xF2; +const RETURN: u8 = 0xF3; +const DELEGATECALL: u8 = 0xF4; +const CREATE2: u8 = 0xF5; +const STATICCALL: u8 = 0xFA; +const REVERT: u8 = 0xFD; +const SELFDESTRUCT: u8 = 0xFF; diff --git a/src/loader/evm/loader.rs b/src/loader/evm/loader.rs index 258ffdd0..06ab7fd8 100644 --- a/src/loader/evm/loader.rs +++ b/src/loader/evm/loader.rs @@ -1,24 +1,49 @@ use crate::{ loader::evm::{ code::{Code, Precompiled}, - modulus, + fe_to_u256, modulus, }, - loader::{EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, - util::{Curve, FieldOps, Itertools, PrimeField, UncompressedEncoding}, + loader::{evm::u256_to_fe, EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, + util::{ + arithmetic::{CurveAffine, FieldOps, PrimeField}, + Itertools, + }, + Error, }; use ethereum_types::{U256, U512}; use std::{ cell::RefCell, + collections::HashMap, fmt::{self, Debug}, iter, - ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, + ops::{Add, AddAssign, DerefMut, Mul, MulAssign, Neg, Sub, SubAssign}, rc::Rc, }; -#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] pub enum Value { Constant(T), Memory(usize), + Negated(Box>), + Sum(Box>, Box>), + Product(Box>, Box>), +} + +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { + self.identifier() == other.identifier() + } +} + +impl Value { + fn identifier(&self) -> String { + match &self { + Value::Constant(_) | Value::Memory(_) => format!("{:?}", self), + Value::Negated(value) => format!("-({:?})", value), + Value::Sum(lhs, rhs) => format!("({:?} + {:?})", lhs, rhs), + Value::Product(lhs, rhs) => format!("({:?} * {:?})", lhs, rhs), + } + } } #[derive(Clone, Debug)] @@ -27,6 +52,7 @@ pub struct EvmLoader { scalar_modulus: U256, code: RefCell, ptr: RefCell, + cache: RefCell>, #[cfg(test)] gas_metering_ids: RefCell>, } @@ -46,13 +72,18 @@ impl EvmLoader { base_modulus, scalar_modulus, code: RefCell::new(code), - ptr: RefCell::new(0), + ptr: Default::default(), + cache: Default::default(), #[cfg(test)] gas_metering_ids: RefCell::new(Vec::new()), }) } - pub fn code(self: &Rc) -> Vec { + pub fn deployment_code(self: &Rc) -> Vec { + Code::deployment(self.runtime_code()) + } + + pub fn runtime_code(self: &Rc) -> Vec { let mut code = self.code.borrow().clone(); let dst = code.len() + 9; code.push(dst) @@ -72,7 +103,41 @@ impl EvmLoader { ptr } - fn scalar(self: &Rc, value: Value) -> Scalar { + pub(crate) fn scalar_modulus(&self) -> U256 { + self.scalar_modulus + } + + pub(crate) fn ptr(&self) -> usize { + *self.ptr.borrow() + } + + pub(crate) fn code_mut(&self) -> impl DerefMut + '_ { + self.code.borrow_mut() + } + + pub(crate) fn scalar(self: &Rc, value: Value) -> Scalar { + let value = if matches!( + value, + Value::Constant(_) | Value::Memory(_) | Value::Negated(_) + ) { + value + } else { + let identifier = value.identifier(); + let some_ptr = self.cache.borrow().get(&identifier).cloned(); + let ptr = if let Some(ptr) = some_ptr { + ptr + } else { + self.push(&Scalar { + loader: self.clone(), + value, + }); + let ptr = self.allocate(0x20); + self.code.borrow_mut().push(ptr).mstore(); + self.cache.borrow_mut().insert(identifier, ptr); + ptr + }; + Value::Memory(ptr) + }; Scalar { loader: self.clone(), value, @@ -87,13 +152,29 @@ impl EvmLoader { } fn push(self: &Rc, scalar: &Scalar) { - match scalar.value { + match scalar.value.clone() { Value::Constant(constant) => { self.code.borrow_mut().push(constant); } Value::Memory(ptr) => { self.code.borrow_mut().push(ptr).mload(); } + Value::Negated(value) => { + self.push(&self.scalar(*value)); + self.code.borrow_mut().push(self.scalar_modulus).sub(); + } + Value::Sum(lhs, rhs) => { + self.code.borrow_mut().push(self.scalar_modulus); + self.push(&self.scalar(*lhs)); + self.push(&self.scalar(*rhs)); + self.code.borrow_mut().addmod(); + } + Value::Product(lhs, rhs) => { + self.code.borrow_mut().push(self.scalar_modulus); + self.push(&self.scalar(*lhs)); + self.push(&self.scalar(*rhs)); + self.code.borrow_mut().mulmod(); + } } } @@ -139,46 +220,36 @@ impl EvmLoader { self.ec_point(Value::Memory(ptr)) } - pub fn calldataload_ec_point_from_limbs( + pub fn ec_point_from_limbs( self: &Rc, - offset: usize, + x_limbs: [Scalar; LIMBS], + y_limbs: [Scalar; LIMBS], ) -> EcPoint { let ptr = self.allocate(0x40); - for (ptr, offset) in [(ptr, offset), (ptr + 0x20, offset + LIMBS * 0x20)] { - for idx in 0..LIMBS { - if idx == 0 { - self.code - .borrow_mut() - // [..., success] - .push(offset) - // [..., success, x_limb_0_ptr] - .calldataload(); - // [..., success, x_limb_0] - } else { + for (ptr, limbs) in [(ptr, x_limbs), (ptr + 0x20, y_limbs)] { + for (idx, limb) in limbs.into_iter().enumerate() { + self.push(&limb); + // [..., success, acc] + if idx > 0 { self.code .borrow_mut() - // [..., success, x_acc] - .push(offset + idx * 0x20) - // [..., success, x_acc, x_limb_i_ptr] - .calldataload() - // [..., success, x_acc, x_limb_i] .push(idx * BITS) - // [..., success, x_acc, x_limb_i, shift] + // [..., success, acc, limb_i, shift] .shl() - // [..., success, x_acc, x_limb_i << shift] + // [..., success, acc, limb_i << shift] .add(); - // [..., success, x_acc] + // [..., success, acc] } } self.code .borrow_mut() - // [..., success, x] + // [..., success, coordinate] .dup(0) - // [..., success, x, x] + // [..., success, coordinate, coordinate] .push(ptr) - // [..., success, x, x, x_ptr] + // [..., success, coordinate, coordinate, ptr] .mstore(); - // [..., success, x] + // [..., success, coordinate] } // [..., success, x, y] self.validate_ec_point(); @@ -258,60 +329,21 @@ impl EvmLoader { .and(); } - pub fn squeeze_challenge(self: &Rc, ptr: usize, len: usize) -> (usize, Scalar) { - assert!(len > 0 && len % 0x20 == 0); - - let (ptr, len) = if len == 0x20 { - let ptr = if ptr + len != *self.ptr.borrow() { - (ptr..ptr + len) - .step_by(0x20) - .map(|ptr| self.dup_scalar(&self.scalar(Value::Memory(ptr)))) - .collect_vec() - .first() - .unwrap() - .ptr() - } else { - ptr - }; - self.code.borrow_mut().push(1).push(ptr + 0x20).mstore8(); - (ptr, len + 1) - } else { - (ptr, len) - }; - - let challenge_ptr = self.allocate(0x20); + pub fn keccak256(self: &Rc, ptr: usize, len: usize) -> usize { let hash_ptr = self.allocate(0x20); - self.code .borrow_mut() - .push(self.scalar_modulus) .push(len) .push(ptr) .keccak256() - .dup(0) .push(hash_ptr) - .mstore() - .r#mod() - .push(challenge_ptr) .mstore(); - - (hash_ptr, self.scalar(Value::Memory(challenge_ptr))) + hash_ptr } pub fn copy_scalar(self: &Rc, scalar: &Scalar, ptr: usize) { - match scalar.value { - Value::Constant(constant) => { - self.code.borrow_mut().push(constant).push(ptr).mstore(); - } - Value::Memory(src_ptr) => { - self.code - .borrow_mut() - .push(src_ptr) - .mload() - .push(ptr) - .mstore(); - } - } + self.push(scalar); + self.code.borrow_mut().push(ptr).mstore(); } pub fn dup_scalar(self: &Rc, scalar: &Scalar) -> Scalar { @@ -320,7 +352,7 @@ impl EvmLoader { self.scalar(Value::Memory(ptr)) } - fn dup_ec_point(self: &Rc, value: &EcPoint) -> EcPoint { + pub fn dup_ec_point(self: &Rc, value: &EcPoint) -> EcPoint { let ptr = self.allocate(0x40); match value.value { Value::Constant((x, y)) => { @@ -345,6 +377,9 @@ impl EvmLoader { .push(ptr + 0x20) .mstore(); } + Value::Negated(_) | Value::Sum(_, _) | Value::Product(_, _) => { + unreachable!() + } } self.ec_point(Value::Memory(ptr)) } @@ -390,14 +425,6 @@ impl EvmLoader { self.ec_point(Value::Memory(rd_ptr)) } - fn ec_point_sub(self: &Rc, _: &EcPoint, _: &EcPoint) -> EcPoint { - unreachable!() - } - - fn ec_point_neg(self: &Rc, _: &EcPoint) -> EcPoint { - unreachable!() - } - fn ec_point_scalar_mul(self: &Rc, ec_point: &EcPoint, scalar: &Scalar) -> EcPoint { let rd_ptr = self.dup_ec_point(ec_point).ptr(); self.dup_scalar(scalar); @@ -449,19 +476,15 @@ impl EvmLoader { } fn add(self: &Rc, lhs: &Scalar, rhs: &Scalar) -> Scalar { - if let (Value::Constant(lhs), Value::Constant(rhs)) = (lhs.value, rhs.value) { + if let (Value::Constant(lhs), Value::Constant(rhs)) = (&lhs.value, &rhs.value) { let out = (U512::from(lhs) + U512::from(rhs)) % U512::from(self.scalar_modulus); return self.scalar(Value::Constant(out.try_into().unwrap())); } - let ptr = self.allocate(0x20); - - self.code.borrow_mut().push(self.scalar_modulus); - self.push(rhs); - self.push(lhs); - self.code.borrow_mut().addmod().push(ptr).mstore(); - - self.scalar(Value::Memory(ptr)) + self.scalar(Value::Sum( + Box::new(lhs.value.clone()), + Box::new(rhs.value.clone()), + )) } fn sub(self: &Rc, lhs: &Scalar, rhs: &Scalar) -> Scalar { @@ -469,31 +492,22 @@ impl EvmLoader { return self.add(lhs, &self.neg(rhs)); } - let ptr = self.allocate(0x20); - - self.code.borrow_mut().push(self.scalar_modulus); - self.push(rhs); - self.code.borrow_mut().push(self.scalar_modulus).sub(); - self.push(lhs); - self.code.borrow_mut().addmod().push(ptr).mstore(); - - self.scalar(Value::Memory(ptr)) + self.scalar(Value::Sum( + Box::new(lhs.value.clone()), + Box::new(Value::Negated(Box::new(rhs.value.clone()))), + )) } fn mul(self: &Rc, lhs: &Scalar, rhs: &Scalar) -> Scalar { - if let (Value::Constant(lhs), Value::Constant(rhs)) = (lhs.value, rhs.value) { + if let (Value::Constant(lhs), Value::Constant(rhs)) = (&lhs.value, &rhs.value) { let out = (U512::from(lhs) * U512::from(rhs)) % U512::from(self.scalar_modulus); return self.scalar(Value::Constant(out.try_into().unwrap())); } - let ptr = self.allocate(0x20); - - self.code.borrow_mut().push(self.scalar_modulus); - self.push(rhs); - self.push(lhs); - self.code.borrow_mut().mulmod().push(ptr).mstore(); - - self.scalar(Value::Memory(ptr)) + self.scalar(Value::Product( + Box::new(lhs.value.clone()), + Box::new(rhs.value.clone()), + )) } fn neg(self: &Rc, scalar: &Scalar) -> Scalar { @@ -501,17 +515,7 @@ impl EvmLoader { return self.scalar(Value::Constant(self.scalar_modulus - constant)); } - let ptr = self.allocate(0x20); - - self.push(scalar); - self.code - .borrow_mut() - .push(self.scalar_modulus) - .sub() - .push(ptr) - .mstore(); - - self.scalar(Value::Memory(ptr)) + self.scalar(Value::Negated(Box::new(scalar.value.clone()))) } } @@ -552,19 +556,15 @@ pub struct EcPoint { } impl EcPoint { - pub(super) fn loader(&self) -> &Rc { + pub(crate) fn loader(&self) -> &Rc { &self.loader } - pub fn value(&self) -> Value<(U256, U256)> { - self.value - } - - pub fn is_const(&self) -> bool { - matches!(self.value, Value::Constant(_)) + pub(crate) fn value(&self) -> Value<(U256, U256)> { + self.value.clone() } - pub fn ptr(&self) -> usize { + pub(crate) fn ptr(&self) -> usize { match self.value { Value::Memory(ptr) => ptr, _ => unreachable!(), @@ -580,70 +580,6 @@ impl Debug for EcPoint { } } -impl Add for EcPoint { - type Output = Self; - - fn add(self, rhs: Self) -> Self { - self.loader.ec_point_add(&self, &rhs) - } -} - -impl Sub for EcPoint { - type Output = Self; - - fn sub(self, rhs: Self) -> Self { - self.loader.ec_point_sub(&self, &rhs) - } -} - -impl Neg for EcPoint { - type Output = Self; - - fn neg(self) -> Self { - self.loader.ec_point_neg(&self) - } -} - -impl<'a> Add<&'a Self> for EcPoint { - type Output = Self; - - fn add(self, rhs: &'a Self) -> Self { - self.loader.ec_point_add(&self, rhs) - } -} - -impl<'a> Sub<&'a Self> for EcPoint { - type Output = Self; - - fn sub(self, rhs: &'a Self) -> Self { - self.loader.ec_point_sub(&self, rhs) - } -} - -impl AddAssign for EcPoint { - fn add_assign(&mut self, rhs: Self) { - *self = self.loader.ec_point_add(self, &rhs); - } -} - -impl SubAssign for EcPoint { - fn sub_assign(&mut self, rhs: Self) { - *self = self.loader.ec_point_sub(self, &rhs); - } -} - -impl<'a> AddAssign<&'a Self> for EcPoint { - fn add_assign(&mut self, rhs: &'a Self) { - *self = self.loader.ec_point_add(self, rhs); - } -} - -impl<'a> SubAssign<&'a Self> for EcPoint { - fn sub_assign(&mut self, rhs: &'a Self) { - *self = self.loader.ec_point_sub(self, rhs); - } -} - impl PartialEq for EcPoint { fn eq(&self, other: &Self) -> bool { self.value == other.value @@ -652,8 +588,8 @@ impl PartialEq for EcPoint { impl LoadedEcPoint for EcPoint where - C: Curve + UncompressedEncoding, - C::Scalar: PrimeField, + C: CurveAffine, + C::ScalarExt: PrimeField, { type Loader = Rc; @@ -668,7 +604,7 @@ where Value::Constant(constant) if constant == U256::one() => ec_point, _ => ec_point.loader.ec_point_scalar_mul(&ec_point, &scalar), }) - .reduce(|acc, ec_point| acc + ec_point) + .reduce(|acc, ec_point| acc.loader.ec_point_add(&acc, &ec_point)) .unwrap() } } @@ -680,18 +616,27 @@ pub struct Scalar { } impl Scalar { - pub fn value(&self) -> Value { - self.value + pub(crate) fn loader(&self) -> &Rc { + &self.loader } - pub fn is_const(&self) -> bool { + pub(crate) fn value(&self) -> Value { + self.value.clone() + } + + pub(crate) fn is_const(&self) -> bool { matches!(self.value, Value::Constant(_)) } - pub fn ptr(&self) -> usize { + pub(crate) fn ptr(&self) -> usize { match self.value { Value::Memory(ptr) => ptr, - _ => unreachable!(), + _ => *self + .loader + .cache + .borrow() + .get(&self.value.identifier()) + .unwrap(), } } } @@ -879,34 +824,164 @@ impl> LoadedScalar for Scalar { impl EcPointLoader for Rc where - C: Curve + UncompressedEncoding, + C: CurveAffine, C::Scalar: PrimeField, { type LoadedEcPoint = EcPoint; fn ec_point_load_const(&self, value: &C) -> EcPoint { - let bytes = value.to_uncompressed(); - let (x, y) = ( - U256::from_little_endian(&bytes[..32]), - U256::from_little_endian(&bytes[32..]), - ); + let coordinates = value.coordinates().unwrap(); + let [x, y] = [coordinates.x(), coordinates.y()] + .map(|coordinate| U256::from_little_endian(coordinate.to_repr().as_ref())); self.ec_point(Value::Constant((x, y))) } + + fn ec_point_assert_eq(&self, _: &str, _: &EcPoint, _: &EcPoint) -> Result<(), Error> { + unimplemented!() + } } impl> ScalarLoader for Rc { type LoadedScalar = Scalar; fn load_const(&self, value: &F) -> Scalar { - self.scalar(Value::Constant(U256::from_little_endian( - value.to_repr().as_slice(), - ))) + self.scalar(Value::Constant(fe_to_u256(*value))) + } + + fn assert_eq(&self, _: &str, _: &Scalar, _: &Scalar) -> Result<(), Error> { + unimplemented!() + } + + fn sum_with_coeff_and_const(&self, values: &[(F, &Scalar)], constant: F) -> Scalar { + if values.is_empty() { + return self.load_const(&constant); + } + + let push_addend = |(coeff, value): &(F, &Scalar)| { + assert_ne!(*coeff, F::zero()); + match (*coeff == F::one(), &value.value) { + (true, _) => { + self.push(value); + } + (false, Value::Constant(value)) => { + self.push(&self.scalar(Value::Constant(fe_to_u256( + *coeff * u256_to_fe::(*value), + )))); + } + (false, _) => { + self.code.borrow_mut().push(self.scalar_modulus); + self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); + self.push(value); + self.code.borrow_mut().mulmod(); + } + } + }; + + let mut values = values.iter(); + if constant == F::zero() { + push_addend(values.next().unwrap()); + } else { + self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))); + } + + let chunk_size = 16 - self.code.borrow().stack_len(); + for values in &values.chunks(chunk_size) { + let values = values.into_iter().collect_vec(); + + self.code.borrow_mut().push(self.scalar_modulus); + for _ in 1..chunk_size.min(values.len()) { + self.code.borrow_mut().dup(0); + } + self.code.borrow_mut().swap(chunk_size.min(values.len())); + + for value in values { + push_addend(value); + self.code.borrow_mut().addmod(); + } + } + + let ptr = self.allocate(0x20); + self.code.borrow_mut().push(ptr).mstore(); + + self.scalar(Value::Memory(ptr)) + } + + fn sum_products_with_coeff_and_const( + &self, + values: &[(F, &Scalar, &Scalar)], + constant: F, + ) -> Scalar { + if values.is_empty() { + return self.load_const(&constant); + } + + let push_addend = |(coeff, lhs, rhs): &(F, &Scalar, &Scalar)| { + assert_ne!(*coeff, F::zero()); + match (*coeff == F::one(), &lhs.value, &rhs.value) { + (_, Value::Constant(lhs), Value::Constant(rhs)) => { + self.push(&self.scalar(Value::Constant(fe_to_u256( + *coeff * u256_to_fe::(*lhs) * u256_to_fe::(*rhs), + )))); + } + (_, value @ Value::Memory(_), Value::Constant(constant)) + | (_, Value::Constant(constant), value @ Value::Memory(_)) => { + self.code.borrow_mut().push(self.scalar_modulus); + self.push(&self.scalar(Value::Constant(fe_to_u256( + *coeff * u256_to_fe::(*constant), + )))); + self.push(&self.scalar(value.clone())); + self.code.borrow_mut().mulmod(); + } + (true, _, _) => { + self.code.borrow_mut().push(self.scalar_modulus); + self.push(lhs); + self.push(rhs); + self.code.borrow_mut().mulmod(); + } + (false, _, _) => { + self.code.borrow_mut().push(self.scalar_modulus).dup(0); + self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); + self.push(lhs); + self.code.borrow_mut().mulmod(); + self.push(rhs); + self.code.borrow_mut().mulmod(); + } + } + }; + + let mut values = values.iter(); + if constant == F::zero() { + push_addend(values.next().unwrap()); + } else { + self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))); + } + + let chunk_size = 16 - self.code.borrow().stack_len(); + for values in &values.chunks(chunk_size) { + let values = values.into_iter().collect_vec(); + + self.code.borrow_mut().push(self.scalar_modulus); + for _ in 1..chunk_size.min(values.len()) { + self.code.borrow_mut().dup(0); + } + self.code.borrow_mut().swap(chunk_size.min(values.len())); + + for value in values { + push_addend(value); + self.code.borrow_mut().addmod(); + } + } + + let ptr = self.allocate(0x20); + self.code.borrow_mut().push(ptr).mstore(); + + self.scalar(Value::Memory(ptr)) } } impl Loader for Rc where - C: Curve + UncompressedEncoding, + C: CurveAffine, C::Scalar: PrimeField, { #[cfg(test)] diff --git a/src/loader/evm/test.rs b/src/loader/evm/test.rs index d4ae4984..e204c1b8 100644 --- a/src/loader/evm/test.rs +++ b/src/loader/evm/test.rs @@ -9,12 +9,6 @@ use std::env::var_os; mod tui; -fn small_address(lsb: u8) -> Address { - let mut address = Address::zero(); - *address.0.last_mut().unwrap() = lsb; - address -} - fn debug() -> bool { matches!( var_os("DEBUG"), @@ -23,9 +17,15 @@ fn debug() -> bool { } pub fn execute(code: Vec, calldata: Vec) -> (bool, u64, Vec) { + assert!( + code.len() <= 0x6000, + "Contract size {} exceeds the limit 24576", + code.len() + ); + let debug = debug(); - let caller = small_address(0xfe); - let callee = small_address(0xff); + let caller = Address::from_low_u64_be(0xfe); + let callee = Address::from_low_u64_be(0xff); let mut evm = ExecutorBuilder::default() .with_gas_limit(u64::MAX.into()) @@ -52,5 +52,5 @@ pub fn execute(code: Vec, calldata: Vec) -> (bool, u64, Vec) { Tui::new(result.debug.unwrap().flatten(0), 0).start(); } - (!result.reverted, result.gas, costs) + (!result.reverted, result.gas_used, costs) } diff --git a/src/loader/evm/test/tui.rs b/src/loader/evm/test/tui.rs index 72866d19..fcaef36c 100644 --- a/src/loader/evm/test/tui.rs +++ b/src/loader/evm/test/tui.rs @@ -60,7 +60,7 @@ impl Tui { .expect("unable to execute disable mouse capture"); println!("{e}"); })); - let tick_rate = Duration::from_millis(200); + let tick_rate = Duration::from_millis(60); let (tx, rx) = mpsc::channel(); thread::spawn(move || { diff --git a/src/loader/evm/transcript.rs b/src/loader/evm/transcript.rs deleted file mode 100644 index a5ff68d2..00000000 --- a/src/loader/evm/transcript.rs +++ /dev/null @@ -1,229 +0,0 @@ -use crate::{ - loader::{ - evm::{ - loader::{EcPoint, EvmLoader, Scalar, Value}, - u256_to_field, - }, - native::NativeLoader, - Loader, - }, - util::{Curve, Group, Itertools, PrimeField, Transcript, TranscriptRead, UncompressedEncoding}, - Error, -}; -use ethereum_types::U256; -use sha3::{Digest, Keccak256}; -use std::{ - io::{self, Read, Write}, - marker::PhantomData, - rc::Rc, -}; - -pub struct MemoryChunk { - ptr: usize, - len: usize, -} - -impl MemoryChunk { - fn new(ptr: usize) -> Self { - Self { ptr, len: 0x20 } - } - - fn reset(&mut self, ptr: usize) { - self.ptr = ptr; - self.len = 0x20; - } - - fn include(&self, ptr: usize, size: usize) -> bool { - let range = self.ptr..=self.ptr + self.len; - range.contains(&ptr) && range.contains(&(ptr + size)) - } - - fn extend(&mut self, ptr: usize, size: usize) { - if !self.include(ptr, size) { - assert_eq!(self.ptr + self.len, ptr); - self.len += size; - } - } -} - -pub struct EvmTranscript, S, B> { - loader: L, - stream: S, - buf: B, - _marker: PhantomData, -} - -impl EvmTranscript, usize, MemoryChunk> -where - C: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - pub fn new(loader: Rc) -> Self { - let ptr = loader.allocate(0x20); - assert_eq!(ptr, 0); - Self { - loader, - stream: 0, - buf: MemoryChunk::new(ptr), - _marker: PhantomData, - } - } -} - -impl Transcript> for EvmTranscript, usize, MemoryChunk> -where - C: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - fn squeeze_challenge(&mut self) -> Scalar { - let (ptr, scalar) = self.loader.squeeze_challenge(self.buf.ptr, self.buf.len); - self.buf.reset(ptr); - scalar - } - - fn common_ec_point(&mut self, ec_point: &EcPoint) -> Result<(), Error> { - if let Value::Memory(ptr) = ec_point.value() { - self.buf.extend(ptr, 0x40); - } else { - unreachable!() - } - Ok(()) - } - - fn common_scalar(&mut self, scalar: &Scalar) -> Result<(), Error> { - match scalar.value() { - Value::Constant(_) if self.buf.ptr == 0 => { - self.loader.copy_scalar(scalar, self.buf.ptr); - } - Value::Memory(ptr) => { - self.buf.extend(ptr, 0x20); - } - _ => unreachable!(), - } - Ok(()) - } -} - -impl TranscriptRead> for EvmTranscript, usize, MemoryChunk> -where - C: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - fn read_scalar(&mut self) -> Result { - let scalar = self.loader.calldataload_scalar(self.stream); - self.stream += 0x20; - self.common_scalar(&scalar)?; - Ok(scalar) - } - - fn read_ec_point(&mut self) -> Result { - let ec_point = self.loader.calldataload_ec_point(self.stream); - self.stream += 0x40; - self.common_ec_point(&ec_point)?; - Ok(ec_point) - } -} - -impl EvmTranscript> -where - C: Curve, -{ - pub fn new(stream: S) -> Self { - Self { - loader: NativeLoader, - stream, - buf: Vec::new(), - _marker: PhantomData, - } - } -} - -impl Transcript for EvmTranscript> -where - C: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - fn squeeze_challenge(&mut self) -> C::Scalar { - let data = self - .buf - .iter() - .cloned() - .chain(if self.buf.len() == 0x20 { - Some(1) - } else { - None - }) - .collect_vec(); - let hash: [u8; 32] = Keccak256::digest(data).into(); - self.buf = hash.to_vec(); - u256_to_field(U256::from_big_endian(hash.as_slice())) - } - - fn common_ec_point(&mut self, ec_point: &C) -> Result<(), Error> { - let uncopressed = ec_point.to_uncompressed(); - self.buf.extend(uncopressed[..32].iter().rev().cloned()); - self.buf.extend(uncopressed[32..].iter().rev().cloned()); - - Ok(()) - } - - fn common_scalar(&mut self, scalar: &C::Scalar) -> Result<(), Error> { - self.buf.extend(scalar.to_repr().as_ref().iter().rev()); - - Ok(()) - } -} - -impl TranscriptRead for EvmTranscript> -where - C: Curve + UncompressedEncoding, - C::Scalar: PrimeField, - S: Read, -{ - fn read_scalar(&mut self) -> Result { - let mut data = [0; 32]; - self.stream - .read_exact(data.as_mut()) - .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; - data.reverse(); - let scalar = ::Scalar::from_repr_vartime(data).ok_or_else(|| { - Error::Transcript( - io::ErrorKind::Other, - "Invalid scalar encoding in proof".to_string(), - ) - })?; - self.common_scalar(&scalar)?; - Ok(scalar) - } - - fn read_ec_point(&mut self) -> Result { - let mut data = [0; 64]; - self.stream - .read_exact(data.as_mut()) - .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; - data.as_mut_slice()[..32].reverse(); - data.as_mut_slice()[32..].reverse(); - let ec_point = C::from_uncompressed(data).ok_or_else(|| { - Error::Transcript( - io::ErrorKind::Other, - "Invalid elliptic curve point encoding in proof".to_string(), - ) - })?; - self.common_ec_point(&ec_point)?; - Ok(ec_point) - } -} - -impl EvmTranscript> -where - C: Curve, - S: Write, -{ - pub fn stream_mut(&mut self) -> &mut S { - &mut self.stream - } - - pub fn finalize(self) -> S { - self.stream - } -} diff --git a/src/loader/evm/util.rs b/src/loader/evm/util.rs new file mode 100644 index 00000000..0d9698bd --- /dev/null +++ b/src/loader/evm/util.rs @@ -0,0 +1,92 @@ +use crate::{ + cost::Cost, + util::{arithmetic::PrimeField, Itertools}, +}; +use ethereum_types::U256; +use std::iter; + +pub struct MemoryChunk { + ptr: usize, + len: usize, +} + +impl MemoryChunk { + pub fn new(ptr: usize) -> Self { + Self { ptr, len: 0 } + } + + pub fn ptr(&self) -> usize { + self.ptr + } + + pub fn len(&self) -> usize { + self.len + } + + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + pub fn end(&self) -> usize { + self.ptr + self.len + } + + pub fn reset(&mut self, ptr: usize) { + self.ptr = ptr; + self.len = 0; + } + + pub fn extend(&mut self, size: usize) { + self.len += size; + } +} + +// Assume fields implements traits in crate `ff` always have little-endian representation. +pub fn fe_to_u256(f: F) -> U256 +where + F: PrimeField, +{ + U256::from_little_endian(f.to_repr().as_ref()) +} + +pub fn u256_to_fe(value: U256) -> F +where + F: PrimeField, +{ + let value = value % modulus::(); + let mut repr = F::Repr::default(); + value.to_little_endian(repr.as_mut()); + F::from_repr(repr).unwrap() +} + +pub fn modulus() -> U256 +where + F: PrimeField, +{ + U256::from_little_endian((-F::one()).to_repr().as_ref()) + 1 +} + +pub fn encode_calldata(instances: &[Vec], proof: &[u8]) -> Vec +where + F: PrimeField, +{ + iter::empty() + .chain( + instances + .iter() + .flatten() + .flat_map(|value| value.to_repr().as_ref().iter().rev().cloned().collect_vec()), + ) + .chain(proof.iter().cloned()) + .collect() +} + +pub fn estimate_gas(cost: Cost) -> usize { + let proof_size = cost.num_commitment * 64 + (cost.num_evaluation + cost.num_instance) * 32; + + let intrinsic_cost = 21000; + let calldata_cost = (proof_size as f64 * 15.25).ceil() as usize; + let ec_operation_cost = 113100 + (cost.num_msm - 2) * 6350; + + intrinsic_cost + calldata_cost + ec_operation_cost +} diff --git a/src/loader/halo2.rs b/src/loader/halo2.rs index 3d227a8a..9eae74e3 100644 --- a/src/loader/halo2.rs +++ b/src/loader/halo2.rs @@ -1,6 +1,29 @@ -mod accumulation; -mod loader; -mod transcript; +pub(crate) mod loader; +mod shim; -pub use loader::Halo2Loader; -pub use transcript::PoseidonTranscript; +#[cfg(test)] +pub(crate) mod test; + +pub use loader::{EcPoint, Halo2Loader, Scalar}; +pub use shim::{Context, EccInstructions, IntegerInstructions}; +pub use util::Valuetools; + +pub use halo2_wrong_ecc; + +mod util { + use halo2_proofs::circuit::Value; + + pub trait Valuetools: Iterator> { + fn fold_zipped(self, init: B, mut f: F) -> Value + where + Self: Sized, + F: FnMut(B, V) -> B, + { + self.into_iter().fold(Value::known(init), |acc, value| { + acc.zip(value).map(|(acc, value)| f(acc, value)) + }) + } + } + + impl>> Valuetools for I {} +} diff --git a/src/loader/halo2/accumulation.rs b/src/loader/halo2/accumulation.rs deleted file mode 100644 index cc78ab83..00000000 --- a/src/loader/halo2/accumulation.rs +++ /dev/null @@ -1,93 +0,0 @@ -use crate::{ - loader::{ - halo2::loader::{Halo2Loader, Scalar}, - LoadedEcPoint, - }, - protocol::Protocol, - scheme::kzg::{AccumulationStrategy, Accumulator, SameCurveAccumulation, MSM}, - util::{Itertools, Transcript}, - Error, -}; -use halo2_curves::CurveAffine; -use halo2_wrong_ecc::AssignedPoint; -use std::rc::Rc; - -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> - SameCurveAccumulation>, LIMBS, BITS> -{ - pub fn finalize( - self, - g1: C, - ) -> ( - AssignedPoint, - AssignedPoint, - ) { - let (lhs, rhs) = self.accumulator.unwrap().evaluate(g1.to_curve()); - let loader = lhs.loader(); - ( - loader.ec_point_nomalize(&lhs.assigned()), - loader.ec_point_nomalize(&rhs.assigned()), - ) - } -} - -impl<'a, 'b, C, T, P, const LIMBS: usize, const BITS: usize> - AccumulationStrategy>, T, P> - for SameCurveAccumulation>, LIMBS, BITS> -where - C: CurveAffine, - T: Transcript>>, -{ - type Output = (); - - fn extract_accumulator( - &self, - protocol: &Protocol, - loader: &Rc>, - transcript: &mut T, - statements: &[Vec>], - ) -> Option>>> { - let accumulator_indices = protocol.accumulator_indices.as_ref()?; - - let challenges = transcript.squeeze_n_challenges(accumulator_indices.len()); - let accumulators = accumulator_indices - .iter() - .map(|indices| { - assert_eq!(indices.len(), 4 * LIMBS); - let assinged = indices - .iter() - .map(|index| statements[index.0][index.1].assigned()) - .collect_vec(); - let lhs = loader.assign_ec_point_from_limbs( - assinged[..LIMBS].to_vec().try_into().unwrap(), - assinged[LIMBS..2 * LIMBS].to_vec().try_into().unwrap(), - ); - let rhs = loader.assign_ec_point_from_limbs( - assinged[2 * LIMBS..3 * LIMBS].to_vec().try_into().unwrap(), - assinged[3 * LIMBS..].to_vec().try_into().unwrap(), - ); - Accumulator::new(MSM::base(lhs), MSM::base(rhs)) - }) - .collect_vec(); - - Some(Accumulator::random_linear_combine( - challenges.into_iter().zip(accumulators), - )) - } - - fn process( - &mut self, - _: &Rc>, - transcript: &mut T, - _: P, - accumulator: Accumulator>>, - ) -> Result { - self.accumulator = Some(match self.accumulator.take() { - Some(curr_accumulator) => { - accumulator + curr_accumulator * &transcript.squeeze_challenge() - } - None => accumulator, - }); - Ok(()) - } -} diff --git a/src/loader/halo2/loader.rs b/src/loader/halo2/loader.rs index e811debb..0d288ce7 100644 --- a/src/loader/halo2/loader.rs +++ b/src/loader/halo2/loader.rs @@ -1,184 +1,133 @@ use crate::{ - loader::{EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, - util::{Curve, Field, FieldOps, Group, Itertools}, -}; -use halo2_curves::CurveAffine; -use halo2_proofs::circuit; -use halo2_wrong_ecc::{ - integer::{ - rns::{Integer, Rns}, - IntegerInstructions, Range, + loader::{ + halo2::shim::{EccInstructions, IntegerInstructions}, + EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader, }, - maingate::{ - AssignedValue, CombinationOptionCommon, MainGate, MainGateInstructions, RegionCtx, Term, + util::{ + arithmetic::{CurveAffine, Field, FieldOps}, + Itertools, }, - AssignedPoint, BaseFieldEccChip, EccConfig, }; -use rand::rngs::OsRng; +use halo2_proofs::circuit; use std::{ - cell::RefCell, + cell::{Ref, RefCell, RefMut}, + collections::btree_map::{BTreeMap, Entry}, fmt::{self, Debug}, iter, - ops::{Add, AddAssign, Deref, DerefMut, Mul, MulAssign, Neg, Sub, SubAssign}, + marker::PhantomData, + ops::{Add, AddAssign, Deref, Mul, MulAssign, Neg, Sub, SubAssign}, rc::Rc, }; -const WINDOW_SIZE: usize = 3; - -#[derive(Clone, Debug)] -pub enum Value { - Constant(T), - Assigned(L), -} - -pub struct Halo2Loader<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> { - rns: Rc>, - ecc_chip: RefCell>, - main_gate: MainGate, - ctx: RefCell>, +#[derive(Debug)] +pub struct Halo2Loader<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { + ecc_chip: RefCell, + ctx: RefCell, + num_scalar: RefCell, num_ec_point: RefCell, + const_ec_point: RefCell>>, + _marker: PhantomData, #[cfg(test)] row_meterings: RefCell>, } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> - Halo2Loader<'a, 'b, C, LIMBS, BITS> -{ - pub fn new(ecc_config: EccConfig, ctx: RegionCtx<'a, 'b, C::Scalar>) -> Rc { - let ecc_chip = BaseFieldEccChip::new(ecc_config); - let main_gate = ecc_chip.main_gate(); +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, EccChip> { + pub fn new(ecc_chip: EccChip, ctx: EccChip::Context) -> Rc { Rc::new(Self { - rns: Rc::new(Rns::construct()), ecc_chip: RefCell::new(ecc_chip), - main_gate, ctx: RefCell::new(ctx), - num_ec_point: RefCell::new(0), + num_scalar: RefCell::default(), + num_ec_point: RefCell::default(), + const_ec_point: RefCell::default(), #[cfg(test)] - row_meterings: RefCell::new(Vec::new()), + row_meterings: RefCell::default(), + _marker: PhantomData, }) } - pub fn rns(&self) -> Rc> { - self.rns.clone() + pub fn into_ctx(self) -> EccChip::Context { + self.ctx.into_inner() } - pub fn ecc_chip(&self) -> impl Deref> + '_ { + pub fn ecc_chip(&self) -> Ref<'_, EccChip> { self.ecc_chip.borrow() } - pub(super) fn ctx_mut(&self) -> impl DerefMut> + '_ { + pub fn scalar_chip(&self) -> Ref<'_, EccChip::ScalarChip> { + Ref::map(self.ecc_chip(), |ecc_chip| ecc_chip.scalar_chip()) + } + + pub fn ctx(&self) -> Ref<'_, EccChip::Context> { + self.ctx.borrow() + } + + pub(crate) fn ctx_mut(&self) -> RefMut<'_, EccChip::Context> { self.ctx.borrow_mut() } - pub fn assign_const_scalar( - self: &Rc, - scalar: C::Scalar, - ) -> Scalar<'a, 'b, C, LIMBS, BITS> { + pub fn assign_const_scalar(self: &Rc, constant: C::Scalar) -> Scalar<'a, C, EccChip> { let assigned = self - .main_gate - .assign_constant(&mut self.ctx_mut(), scalar) + .scalar_chip() + .assign_constant(&mut self.ctx_mut(), constant) .unwrap(); self.scalar(Value::Assigned(assigned)) } pub fn assign_scalar( self: &Rc, - scalar: circuit::Value, - ) -> Scalar<'a, 'b, C, LIMBS, BITS> { + scalar: circuit::Value, + ) -> Scalar<'a, C, EccChip> { let assigned = self - .main_gate - .assign_value(&mut self.ctx_mut(), scalar) + .scalar_chip() + .assign_integer(&mut self.ctx_mut(), scalar) .unwrap(); self.scalar(Value::Assigned(assigned)) } - pub fn scalar( + pub(crate) fn scalar( self: &Rc, - value: Value>, - ) -> Scalar<'a, 'b, C, LIMBS, BITS> { + value: Value, + ) -> Scalar<'a, C, EccChip> { + let index = *self.num_scalar.borrow(); + *self.num_scalar.borrow_mut() += 1; Scalar { loader: self.clone(), + index, value, } } - pub fn assign_const_ec_point(self: &Rc, ec_point: C) -> EcPoint<'a, 'b, C, LIMBS, BITS> { - let assigned = self - .ecc_chip - .borrow() - .assign_constant(&mut self.ctx_mut(), ec_point) - .unwrap(); - self.ec_point(assigned) + pub fn assign_const_ec_point(self: &Rc, constant: C) -> EcPoint<'a, C, EccChip> { + let coordinates = constant.coordinates().unwrap(); + match self + .const_ec_point + .borrow_mut() + .entry((*coordinates.x(), *coordinates.y())) + { + Entry::Occupied(entry) => entry.get().clone(), + Entry::Vacant(entry) => { + let assigned = self + .ecc_chip() + .assign_point(&mut self.ctx_mut(), circuit::Value::known(constant)) + .unwrap(); + let ec_point = self.ec_point(assigned); + entry.insert(ec_point).clone() + } + } } pub fn assign_ec_point( self: &Rc, ec_point: circuit::Value, - ) -> EcPoint<'a, 'b, C, LIMBS, BITS> { + ) -> EcPoint<'a, C, EccChip> { let assigned = self - .ecc_chip - .borrow() + .ecc_chip() .assign_point(&mut self.ctx_mut(), ec_point) .unwrap(); self.ec_point(assigned) } - pub fn assign_ec_point_from_limbs( - self: &Rc, - x_limbs: [AssignedValue; LIMBS], - y_limbs: [AssignedValue; LIMBS], - ) -> EcPoint<'a, 'b, C, LIMBS, BITS> { - let [x, y] = [&x_limbs, &y_limbs] - .map(|limbs| { - limbs.iter().enumerate().fold( - circuit::Value::known([C::Scalar::zero(); LIMBS]), - |acc, (idx, limb)| { - acc.zip(limb.value()).map(|(mut acc, limb)| { - acc[idx] = *limb; - acc - }) - }, - ) - }) - .map(|limbs| { - self.ecc_chip - .borrow() - .integer_chip() - .assign_integer( - &mut self.ctx_mut(), - limbs - .map(|limbs| Integer::from_limbs(&limbs, self.rns.clone())) - .into(), - Range::Remainder, - ) - .unwrap() - }); - - let ec_point = AssignedPoint::new(x, y); - self.ecc_chip() - .assert_is_on_curve(&mut self.ctx_mut(), &ec_point) - .unwrap(); - - for (src, dst) in x_limbs.iter().chain(y_limbs.iter()).zip( - ec_point - .get_x() - .limbs() - .iter() - .chain(ec_point.get_y().limbs().iter()), - ) { - self.ctx - .borrow_mut() - .constrain_equal(src.cell(), dst.as_ref().cell()) - .unwrap(); - } - - self.ec_point(ec_point) - } - - pub fn ec_point( - self: &Rc, - assigned: AssignedPoint, - ) -> EcPoint<'a, 'b, C, LIMBS, BITS> { + fn ec_point(self: &Rc, assigned: EccChip::AssignedEcPoint) -> EcPoint<'a, C, EccChip> { let index = *self.num_ec_point.borrow(); *self.num_ec_point.borrow_mut() += 1; EcPoint { @@ -188,71 +137,66 @@ impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> } } - pub fn ec_point_nomalize( - self: &Rc, - assigned: &AssignedPoint, - ) -> AssignedPoint { - self.ecc_chip() - .normalize(&mut self.ctx_mut(), assigned) - .unwrap() - } - fn add( self: &Rc, - lhs: &Scalar<'a, 'b, C, LIMBS, BITS>, - rhs: &Scalar<'a, 'b, C, LIMBS, BITS>, - ) -> Scalar<'a, 'b, C, LIMBS, BITS> { + lhs: &Scalar<'a, C, EccChip>, + rhs: &Scalar<'a, C, EccChip>, + ) -> Scalar<'a, C, EccChip> { let output = match (&lhs.value, &rhs.value) { (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs + rhs), (Value::Assigned(assigned), Value::Constant(constant)) - | (Value::Constant(constant), Value::Assigned(assigned)) => { - MainGateInstructions::add_constant( - &self.main_gate, + | (Value::Constant(constant), Value::Assigned(assigned)) => self + .scalar_chip() + .sum_with_coeff_and_const( &mut self.ctx_mut(), - assigned, + &[(C::Scalar::one(), assigned.clone())], *constant, ) .map(Value::Assigned) - .unwrap() - } - (Value::Assigned(lhs), Value::Assigned(rhs)) => { - MainGateInstructions::add(&self.main_gate, &mut self.ctx_mut(), lhs, rhs) - .map(Value::Assigned) - .unwrap() - } + .unwrap(), + (Value::Assigned(lhs), Value::Assigned(rhs)) => self + .scalar_chip() + .sum_with_coeff_and_const( + &mut self.ctx_mut(), + &[ + (C::Scalar::one(), lhs.clone()), + (C::Scalar::one(), rhs.clone()), + ], + C::Scalar::zero(), + ) + .map(Value::Assigned) + .unwrap(), }; self.scalar(output) } fn sub( self: &Rc, - lhs: &Scalar<'a, 'b, C, LIMBS, BITS>, - rhs: &Scalar<'a, 'b, C, LIMBS, BITS>, - ) -> Scalar<'a, 'b, C, LIMBS, BITS> { + lhs: &Scalar<'a, C, EccChip>, + rhs: &Scalar<'a, C, EccChip>, + ) -> Scalar<'a, C, EccChip> { let output = match (&lhs.value, &rhs.value) { (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs - rhs), - (Value::Constant(constant), Value::Assigned(assigned)) => { - MainGateInstructions::neg_with_constant( - &self.main_gate, + (Value::Constant(constant), Value::Assigned(assigned)) => self + .scalar_chip() + .sum_with_coeff_and_const( &mut self.ctx_mut(), - assigned, + &[(-C::Scalar::one(), assigned.clone())], *constant, ) .map(Value::Assigned) - .unwrap() - } - (Value::Assigned(assigned), Value::Constant(constant)) => { - MainGateInstructions::add_constant( - &self.main_gate, + .unwrap(), + (Value::Assigned(assigned), Value::Constant(constant)) => self + .scalar_chip() + .sum_with_coeff_and_const( &mut self.ctx_mut(), - assigned, - constant.neg(), + &[(C::Scalar::one(), assigned.clone())], + -*constant, ) .map(Value::Assigned) - .unwrap() - } + .unwrap(), (Value::Assigned(lhs), Value::Assigned(rhs)) => { - MainGateInstructions::sub(&self.main_gate, &mut self.ctx_mut(), lhs, rhs) + IntegerInstructions::sub(self.scalar_chip().deref(), &mut self.ctx_mut(), lhs, rhs) .map(Value::Assigned) .unwrap() } @@ -262,89 +206,78 @@ impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> fn mul( self: &Rc, - lhs: &Scalar<'a, 'b, C, LIMBS, BITS>, - rhs: &Scalar<'a, 'b, C, LIMBS, BITS>, - ) -> Scalar<'a, 'b, C, LIMBS, BITS> { + lhs: &Scalar<'a, C, EccChip>, + rhs: &Scalar<'a, C, EccChip>, + ) -> Scalar<'a, C, EccChip> { let output = match (&lhs.value, &rhs.value) { (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs * rhs), (Value::Assigned(assigned), Value::Constant(constant)) - | (Value::Constant(constant), Value::Assigned(assigned)) => { - MainGateInstructions::apply( - &self.main_gate, + | (Value::Constant(constant), Value::Assigned(assigned)) => self + .scalar_chip() + .sum_with_coeff_and_const( &mut self.ctx_mut(), - [ - Term::Assigned(assigned, *constant), - Term::unassigned_to_sub( - assigned.value().map(|assigned| *assigned * constant), - ), - ], + &[(*constant, assigned.clone())], C::Scalar::zero(), - CombinationOptionCommon::OneLinerAdd.into(), ) - .map(|mut assigned| Value::Assigned(assigned.swap_remove(1))) - .unwrap() - } - (Value::Assigned(lhs), Value::Assigned(rhs)) => { - MainGateInstructions::mul(&self.main_gate, &mut self.ctx_mut(), lhs, rhs) - .map(Value::Assigned) - .unwrap() - } + .map(Value::Assigned) + .unwrap(), + (Value::Assigned(lhs), Value::Assigned(rhs)) => self + .scalar_chip() + .sum_products_with_coeff_and_const( + &mut self.ctx_mut(), + &[(C::Scalar::one(), lhs.clone(), rhs.clone())], + C::Scalar::zero(), + ) + .map(Value::Assigned) + .unwrap(), }; self.scalar(output) } - fn neg( - self: &Rc, - scalar: &Scalar<'a, 'b, C, LIMBS, BITS>, - ) -> Scalar<'a, 'b, C, LIMBS, BITS> { + fn neg(self: &Rc, scalar: &Scalar<'a, C, EccChip>) -> Scalar<'a, C, EccChip> { let output = match &scalar.value { Value::Constant(constant) => Value::Constant(constant.neg()), - Value::Assigned(assigned) => MainGateInstructions::neg_with_constant( - &self.main_gate, - &mut self.ctx_mut(), - assigned, - C::Scalar::zero(), - ) - .map(Value::Assigned) - .unwrap(), + Value::Assigned(assigned) => { + IntegerInstructions::neg(self.scalar_chip().deref(), &mut self.ctx_mut(), assigned) + .map(Value::Assigned) + .unwrap() + } }; self.scalar(output) } - fn invert( - self: &Rc, - scalar: &Scalar<'a, 'b, C, LIMBS, BITS>, - ) -> Scalar<'a, 'b, C, LIMBS, BITS> { + fn invert(self: &Rc, scalar: &Scalar<'a, C, EccChip>) -> Scalar<'a, C, EccChip> { let output = match &scalar.value { Value::Constant(constant) => Value::Constant(Field::invert(constant).unwrap()), - Value::Assigned(assigned) => { - let (inv, non_invertable) = - MainGateInstructions::invert(&self.main_gate, &mut self.ctx_mut(), assigned) - .unwrap(); - self.main_gate - .assert_zero(&mut self.ctx_mut(), &non_invertable) - .unwrap(); - Value::Assigned(inv) - } + Value::Assigned(assigned) => Value::Assigned( + IntegerInstructions::invert( + self.scalar_chip().deref(), + &mut self.ctx_mut(), + assigned, + ) + .unwrap(), + ), }; self.scalar(output) } } #[cfg(test)] -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> - Halo2Loader<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, EccChip> { fn start_row_metering(self: &Rc, identifier: &str) { + use crate::loader::halo2::shim::Context; + self.row_meterings .borrow_mut() - .push((identifier.to_string(), *self.ctx.borrow().offset)) + .push((identifier.to_string(), self.ctx().offset())) } fn end_row_metering(self: &Rc) { + use crate::loader::halo2::shim::Context; + let mut row_meterings = self.row_meterings.borrow_mut(); let (_, row) = row_meterings.last_mut().unwrap(); - *row = *self.ctx.borrow().offset - *row; + *row = self.ctx().offset() - *row; } pub fn print_row_metering(self: &Rc) { @@ -354,14 +287,25 @@ impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> } } +#[derive(Clone, Debug)] +pub enum Value { + Constant(T), + Assigned(L), +} + #[derive(Clone)] -pub struct Scalar<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> { - loader: Rc>, - value: Value>, +pub struct Scalar<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { + loader: Rc>, + index: usize, + value: Value, } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Scalar<'a, 'b, C, LIMBS, BITS> { - pub fn assigned(&self) -> AssignedValue { +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Scalar<'a, C, EccChip> { + pub fn loader(&self) -> &Rc> { + &self.loader + } + + pub(crate) fn assigned(&self) -> EccChip::AssignedScalar { match &self.value { Value::Constant(constant) => self.loader.assign_const_scalar(*constant).assigned(), Value::Assigned(assigned) => assigned.clone(), @@ -369,19 +313,23 @@ impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Scalar<'a, ' } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> LoadedScalar - for Scalar<'a, 'b, C, LIMBS, BITS> +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> PartialEq for Scalar<'a, C, EccChip> { + fn eq(&self, other: &Self) -> bool { + self.index == other.index + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> LoadedScalar + for Scalar<'a, C, EccChip> { - type Loader = Rc>; + type Loader = Rc>; fn loader(&self) -> &Self::Loader { &self.loader } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Debug - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Debug for Scalar<'a, C, EccChip> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Scalar") .field("value", &self.value) @@ -389,166 +337,146 @@ impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Debug } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> FieldOps - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> FieldOps for Scalar<'a, C, EccChip> { fn invert(&self) -> Option { - Some((&self.loader).invert(self)) + Some(self.loader.invert(self)) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Add - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Add for Scalar<'a, C, EccChip> { type Output = Self; fn add(self, rhs: Self) -> Self::Output { - (&self.loader).add(&self, &rhs) + Halo2Loader::add(&self.loader, &self, &rhs) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Sub - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Sub for Scalar<'a, C, EccChip> { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { - (&self.loader).sub(&self, &rhs) + Halo2Loader::sub(&self.loader, &self, &rhs) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Mul - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Mul for Scalar<'a, C, EccChip> { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { - (&self.loader).mul(&self, &rhs) + Halo2Loader::mul(&self.loader, &self, &rhs) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Neg - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Neg for Scalar<'a, C, EccChip> { type Output = Self; fn neg(self) -> Self::Output { - (&self.loader).neg(&self) + Halo2Loader::neg(&self.loader, &self) } } -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> Add<&'c Self> - for Scalar<'a, 'b, C, LIMBS, BITS> +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> Add<&'b Self> + for Scalar<'a, C, EccChip> { type Output = Self; - fn add(self, rhs: &'c Self) -> Self::Output { - (&self.loader).add(&self, rhs) + fn add(self, rhs: &'b Self) -> Self::Output { + Halo2Loader::add(&self.loader, &self, rhs) } } -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> Sub<&'c Self> - for Scalar<'a, 'b, C, LIMBS, BITS> +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> Sub<&'b Self> + for Scalar<'a, C, EccChip> { type Output = Self; - fn sub(self, rhs: &'c Self) -> Self::Output { - (&self.loader).sub(&self, rhs) + fn sub(self, rhs: &'b Self) -> Self::Output { + Halo2Loader::sub(&self.loader, &self, rhs) } } -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> Mul<&'c Self> - for Scalar<'a, 'b, C, LIMBS, BITS> +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> Mul<&'b Self> + for Scalar<'a, C, EccChip> { type Output = Self; - fn mul(self, rhs: &'c Self) -> Self::Output { - (&self.loader).mul(&self, rhs) + fn mul(self, rhs: &'b Self) -> Self::Output { + Halo2Loader::mul(&self.loader, &self, rhs) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> AddAssign - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> AddAssign for Scalar<'a, C, EccChip> { fn add_assign(&mut self, rhs: Self) { - *self = (&self.loader).add(self, &rhs) + *self = Halo2Loader::add(&self.loader, self, &rhs) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> SubAssign - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> SubAssign for Scalar<'a, C, EccChip> { fn sub_assign(&mut self, rhs: Self) { - *self = (&self.loader).sub(self, &rhs) + *self = Halo2Loader::sub(&self.loader, self, &rhs) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> MulAssign - for Scalar<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> MulAssign for Scalar<'a, C, EccChip> { fn mul_assign(&mut self, rhs: Self) { - *self = (&self.loader).mul(self, &rhs) + *self = Halo2Loader::mul(&self.loader, self, &rhs) } } -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> AddAssign<&'c Self> - for Scalar<'a, 'b, C, LIMBS, BITS> +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> AddAssign<&'b Self> + for Scalar<'a, C, EccChip> { - fn add_assign(&mut self, rhs: &'c Self) { - *self = (&self.loader).add(self, rhs) + fn add_assign(&mut self, rhs: &'b Self) { + *self = Halo2Loader::add(&self.loader, self, rhs) } } -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> SubAssign<&'c Self> - for Scalar<'a, 'b, C, LIMBS, BITS> +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> SubAssign<&'b Self> + for Scalar<'a, C, EccChip> { - fn sub_assign(&mut self, rhs: &'c Self) { - *self = (&self.loader).sub(self, rhs) + fn sub_assign(&mut self, rhs: &'b Self) { + *self = Halo2Loader::sub(&self.loader, self, rhs) } } -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> MulAssign<&'c Self> - for Scalar<'a, 'b, C, LIMBS, BITS> +impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> MulAssign<&'b Self> + for Scalar<'a, C, EccChip> { - fn mul_assign(&mut self, rhs: &'c Self) { - *self = (&self.loader).mul(self, rhs) + fn mul_assign(&mut self, rhs: &'b Self) { + *self = Halo2Loader::mul(&self.loader, self, rhs) } } #[derive(Clone)] -pub struct EcPoint<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> { - loader: Rc>, +pub struct EcPoint<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { + loader: Rc>, index: usize, - assigned: AssignedPoint, + assigned: EccChip::AssignedEcPoint, } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> - EcPoint<'a, 'b, C, LIMBS, BITS> -{ - pub fn assigned(&self) -> AssignedPoint { +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPoint<'a, C, EccChip> { + pub fn assigned(&self) -> EccChip::AssignedEcPoint { self.assigned.clone() } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> PartialEq - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> PartialEq for EcPoint<'a, C, EccChip> { fn eq(&self, other: &Self) -> bool { self.index == other.index } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> LoadedEcPoint - for EcPoint<'a, 'b, C, LIMBS, BITS> +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> LoadedEcPoint + for EcPoint<'a, C, EccChip> { - type Loader = Rc>; + type Loader = Rc>; fn loader(&self) -> &Self::Loader { &self.loader } fn multi_scalar_multiplication( - pairs: impl IntoIterator, Self)>, + pairs: impl IntoIterator, Self)>, ) -> Self { let pairs = pairs.into_iter().collect_vec(); let loader = &pairs[0].0.loader; @@ -570,42 +498,37 @@ impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> LoadedEcPoin .chain(if scaled.is_empty() { None } else { - let aux_generator = ::CurveExt::random(OsRng).to_affine(); - loader - .ecc_chip - .borrow_mut() - .assign_aux_generator( - &mut loader.ctx.borrow_mut(), - circuit::Value::known(aux_generator), - ) - .unwrap(); - loader - .ecc_chip - .borrow_mut() - .assign_aux(&mut loader.ctx.borrow_mut(), WINDOW_SIZE, scaled.len()) - .unwrap(); Some( loader .ecc_chip - .borrow() - .mul_batch_1d_horizontal(&mut loader.ctx.borrow_mut(), scaled, WINDOW_SIZE) + .borrow_mut() + .multi_scalar_multiplication(&mut loader.ctx_mut(), scaled) .unwrap(), ) }) .chain(non_scaled) .reduce(|acc, ec_point| { - (loader.ecc_chip().deref()) - .add(&mut loader.ctx.borrow_mut(), &acc, &ec_point) + EccInstructions::add( + loader.ecc_chip().deref(), + &mut loader.ctx_mut(), + &acc, + &ec_point, + ) + .unwrap() + }) + .map(|output| { + loader + .ecc_chip() + .normalize(&mut loader.ctx_mut(), &output) .unwrap() }) .unwrap(); + loader.ec_point(output) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Debug - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Debug for EcPoint<'a, C, EccChip> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("EcPoint") .field("index", &self.index) @@ -614,110 +537,82 @@ impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Debug } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Add - for EcPoint<'a, 'b, C, LIMBS, BITS> +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> ScalarLoader + for Rc> { - type Output = Self; - - fn add(self, _: Self) -> Self::Output { - todo!() - } -} - -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Sub - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ - type Output = Self; - - fn sub(self, _: Self) -> Self::Output { - todo!() - } -} + type LoadedScalar = Scalar<'a, C, EccChip>; -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Neg - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ - type Output = Self; - - fn neg(self) -> Self::Output { - todo!() - } -} - -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> Add<&'c Self> - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ - type Output = Self; - - fn add(self, rhs: &'c Self) -> Self::Output { - self + rhs.clone() - } -} - -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> Sub<&'c Self> - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ - type Output = Self; - - fn sub(self, rhs: &'c Self) -> Self::Output { - self - rhs.clone() - } -} - -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> AddAssign - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ - fn add_assign(&mut self, rhs: Self) { - *self = self.clone() + rhs - } -} - -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> SubAssign - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ - fn sub_assign(&mut self, rhs: Self) { - *self = self.clone() - rhs - } -} - -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> AddAssign<&'c Self> - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ - fn add_assign(&mut self, rhs: &'c Self) { - *self = self.clone() + rhs - } -} - -impl<'a, 'b, 'c, C: CurveAffine, const LIMBS: usize, const BITS: usize> SubAssign<&'c Self> - for EcPoint<'a, 'b, C, LIMBS, BITS> -{ - fn sub_assign(&mut self, rhs: &'c Self) { - *self = self.clone() - rhs - } -} - -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> ScalarLoader - for Rc> -{ - type LoadedScalar = Scalar<'a, 'b, C, LIMBS, BITS>; - - fn load_const(&self, value: &C::Scalar) -> Scalar<'a, 'b, C, LIMBS, BITS> { + fn load_const(&self, value: &C::Scalar) -> Scalar<'a, C, EccChip> { self.scalar(Value::Constant(*value)) } -} -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> EcPointLoader - for Rc> -{ - type LoadedEcPoint = EcPoint<'a, 'b, C, LIMBS, BITS>; - - fn ec_point_load_const(&self, ec_point: &C::CurveExt) -> EcPoint<'a, 'b, C, LIMBS, BITS> { - self.assign_const_ec_point(ec_point.to_affine()) + fn assert_eq( + &self, + annotation: &str, + lhs: &Scalar<'a, C, EccChip>, + rhs: &Scalar<'a, C, EccChip>, + ) -> Result<(), crate::Error> { + self.scalar_chip() + .assert_equal(&mut self.ctx_mut(), &lhs.assigned(), &rhs.assigned()) + .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())) + } + + fn sum_with_coeff_and_const( + &self, + values: &[(C::Scalar, &Scalar<'a, C, EccChip>)], + constant: C::Scalar, + ) -> Scalar<'a, C, EccChip> { + let values = values + .iter() + .map(|(coeff, value)| (*coeff, value.assigned())) + .collect_vec(); + self.scalar(Value::Assigned( + self.scalar_chip() + .sum_with_coeff_and_const(&mut self.ctx_mut(), &values, constant) + .unwrap(), + )) + } + + fn sum_products_with_coeff_and_const( + &self, + values: &[(C::Scalar, &Scalar<'a, C, EccChip>, &Scalar<'a, C, EccChip>)], + constant: C::Scalar, + ) -> Scalar<'a, C, EccChip> { + let values = values + .iter() + .map(|(coeff, lhs, rhs)| (*coeff, lhs.assigned(), rhs.assigned())) + .collect_vec(); + self.scalar(Value::Assigned( + self.scalar_chip() + .sum_products_with_coeff_and_const(&mut self.ctx_mut(), &values, constant) + .unwrap(), + )) + } +} + +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPointLoader + for Rc> +{ + type LoadedEcPoint = EcPoint<'a, C, EccChip>; + + fn ec_point_load_const(&self, ec_point: &C) -> EcPoint<'a, C, EccChip> { + self.assign_const_ec_point(*ec_point) + } + + fn ec_point_assert_eq( + &self, + annotation: &str, + lhs: &EcPoint<'a, C, EccChip>, + rhs: &EcPoint<'a, C, EccChip>, + ) -> Result<(), crate::Error> { + self.ecc_chip() + .assert_equal(&mut self.ctx_mut(), &lhs.assigned(), &rhs.assigned()) + .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())) } } -impl<'a, 'b, C: CurveAffine, const LIMBS: usize, const BITS: usize> Loader - for Rc> +impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Loader + for Rc> { #[cfg(test)] fn start_cost_metering(&self, identifier: &str) { diff --git a/src/loader/halo2/shim.rs b/src/loader/halo2/shim.rs new file mode 100644 index 00000000..67f06cc9 --- /dev/null +++ b/src/loader/halo2/shim.rs @@ -0,0 +1,403 @@ +use crate::util::arithmetic::{CurveAffine, FieldExt}; +use halo2_proofs::{ + circuit::{Cell, Value}, + plonk::Error, +}; +use std::fmt::Debug; + +pub trait Context: Debug { + fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error>; + + fn offset(&self) -> usize; +} + +pub trait IntegerInstructions<'a, F: FieldExt>: Clone + Debug { + type Context: Context; + type Integer: Clone + Debug; + type AssignedInteger: Clone + Debug; + + fn integer(&self, fe: F) -> Self::Integer; + + fn assign_integer( + &self, + ctx: &mut Self::Context, + integer: Value, + ) -> Result; + + fn assign_constant( + &self, + ctx: &mut Self::Context, + integer: F, + ) -> Result; + + fn sum_with_coeff_and_const( + &self, + ctx: &mut Self::Context, + values: &[(F::Scalar, Self::AssignedInteger)], + constant: F::Scalar, + ) -> Result; + + fn sum_products_with_coeff_and_const( + &self, + ctx: &mut Self::Context, + values: &[(F::Scalar, Self::AssignedInteger, Self::AssignedInteger)], + constant: F::Scalar, + ) -> Result; + + fn sub( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + b: &Self::AssignedInteger, + ) -> Result; + + fn neg( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + ) -> Result; + + fn invert( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + ) -> Result; + + fn assert_equal( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + b: &Self::AssignedInteger, + ) -> Result<(), Error>; +} + +pub trait EccInstructions<'a, C: CurveAffine>: Clone + Debug { + type Context: Context; + type ScalarChip: IntegerInstructions< + 'a, + C::Scalar, + Context = Self::Context, + Integer = Self::Scalar, + AssignedInteger = Self::AssignedScalar, + >; + type AssignedEcPoint: Clone + Debug; + type Scalar: Clone + Debug; + type AssignedScalar: Clone + Debug; + + fn scalar_chip(&self) -> &Self::ScalarChip; + + fn assign_constant( + &self, + ctx: &mut Self::Context, + point: C, + ) -> Result; + + fn assign_point( + &self, + ctx: &mut Self::Context, + point: Value, + ) -> Result; + + fn add( + &self, + ctx: &mut Self::Context, + p0: &Self::AssignedEcPoint, + p1: &Self::AssignedEcPoint, + ) -> Result; + + fn multi_scalar_multiplication( + &mut self, + ctx: &mut Self::Context, + pairs: Vec<(Self::AssignedEcPoint, Self::AssignedScalar)>, + ) -> Result; + + fn normalize( + &self, + ctx: &mut Self::Context, + point: &Self::AssignedEcPoint, + ) -> Result; + + fn assert_equal( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedEcPoint, + b: &Self::AssignedEcPoint, + ) -> Result<(), Error>; +} + +mod halo2_wrong { + use crate::{ + loader::halo2::{Context, EccInstructions, IntegerInstructions}, + util::{ + arithmetic::{CurveAffine, FieldExt, Group}, + Itertools, + }, + }; + use halo2_proofs::{ + circuit::{AssignedCell, Cell, Value}, + plonk::Error, + }; + use halo2_wrong_ecc::{ + integer::rns::Common, + maingate::{ + CombinationOption, CombinationOptionCommon, MainGate, MainGateInstructions, RegionCtx, + Term, + }, + AssignedPoint, BaseFieldEccChip, + }; + use rand::rngs::OsRng; + + impl<'a, F: FieldExt> Context for RegionCtx<'a, F> { + fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error> { + self.constrain_equal(lhs, rhs) + } + + fn offset(&self) -> usize { + self.offset() + } + } + + impl<'a, F: FieldExt> IntegerInstructions<'a, F> for MainGate { + type Context = RegionCtx<'a, F>; + type Integer = F; + type AssignedInteger = AssignedCell; + + fn integer(&self, scalar: F) -> Self::Integer { + scalar + } + + fn assign_integer( + &self, + ctx: &mut Self::Context, + integer: Value, + ) -> Result { + self.assign_value(ctx, integer) + } + + fn assign_constant( + &self, + ctx: &mut Self::Context, + integer: F, + ) -> Result { + MainGateInstructions::assign_constant(self, ctx, integer) + } + + fn sum_with_coeff_and_const( + &self, + ctx: &mut Self::Context, + values: &[(F, Self::AssignedInteger)], + constant: F, + ) -> Result { + self.compose( + ctx, + &values + .iter() + .map(|(coeff, assigned)| Term::Assigned(assigned, *coeff)) + .collect_vec(), + constant, + ) + } + + fn sum_products_with_coeff_and_const( + &self, + ctx: &mut Self::Context, + values: &[(F, Self::AssignedInteger, Self::AssignedInteger)], + constant: F, + ) -> Result { + match values.len() { + 0 => MainGateInstructions::assign_constant(self, ctx, constant), + 1 => { + let (scalar, lhs, rhs) = &values[0]; + let output = lhs + .value() + .zip(rhs.value()) + .map(|(lhs, rhs)| *scalar * lhs * rhs + constant); + + Ok(self + .apply( + ctx, + [ + Term::Zero, + Term::Zero, + Term::assigned_to_mul(lhs), + Term::assigned_to_mul(rhs), + Term::unassigned_to_sub(output), + ], + constant, + CombinationOption::OneLinerDoubleMul(*scalar), + )? + .swap_remove(4)) + } + _ => { + let (scalar, lhs, rhs) = &values[0]; + self.apply( + ctx, + [Term::assigned_to_mul(lhs), Term::assigned_to_mul(rhs)], + constant, + CombinationOptionCommon::CombineToNextScaleMul(-F::one(), *scalar).into(), + )?; + let acc = + Value::known(*scalar) * lhs.value() * rhs.value() + Value::known(constant); + let output = values.iter().skip(1).fold( + Ok::<_, Error>(acc), + |acc, (scalar, lhs, rhs)| { + acc.and_then(|acc| { + self.apply( + ctx, + [ + Term::assigned_to_mul(lhs), + Term::assigned_to_mul(rhs), + Term::Zero, + Term::Zero, + Term::Unassigned(acc, F::one()), + ], + F::zero(), + CombinationOptionCommon::CombineToNextScaleMul( + -F::one(), + *scalar, + ) + .into(), + )?; + Ok(acc + Value::known(*scalar) * lhs.value() * rhs.value()) + }) + }, + )?; + self.apply( + ctx, + [ + Term::Zero, + Term::Zero, + Term::Zero, + Term::Zero, + Term::Unassigned(output, F::zero()), + ], + F::zero(), + CombinationOptionCommon::OneLinerAdd.into(), + ) + .map(|mut outputs| outputs.swap_remove(4)) + } + } + } + + fn sub( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + b: &Self::AssignedInteger, + ) -> Result { + MainGateInstructions::sub(self, ctx, a, b) + } + + fn neg( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + ) -> Result { + MainGateInstructions::neg_with_constant(self, ctx, a, F::zero()) + } + + fn invert( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + ) -> Result { + MainGateInstructions::invert_unsafe(self, ctx, a) + } + + fn assert_equal( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + b: &Self::AssignedInteger, + ) -> Result<(), Error> { + let mut eq = true; + a.value().zip(b.value()).map(|(lhs, rhs)| { + eq &= lhs == rhs; + }); + MainGateInstructions::assert_equal(self, ctx, a, b) + .and(eq.then_some(()).ok_or(Error::Synthesis)) + } + } + + impl<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize> EccInstructions<'a, C> + for BaseFieldEccChip + { + type Context = RegionCtx<'a, C::Scalar>; + type ScalarChip = MainGate; + type AssignedEcPoint = AssignedPoint; + type Scalar = C::Scalar; + type AssignedScalar = AssignedCell; + + fn scalar_chip(&self) -> &Self::ScalarChip { + self.main_gate() + } + + fn assign_constant( + &self, + ctx: &mut Self::Context, + point: C, + ) -> Result { + self.assign_constant(ctx, point) + } + + fn assign_point( + &self, + ctx: &mut Self::Context, + point: Value, + ) -> Result { + self.assign_point(ctx, point) + } + + fn add( + &self, + ctx: &mut Self::Context, + p0: &Self::AssignedEcPoint, + p1: &Self::AssignedEcPoint, + ) -> Result { + self.add(ctx, p0, p1) + } + + fn multi_scalar_multiplication( + &mut self, + ctx: &mut Self::Context, + pairs: Vec<(Self::AssignedEcPoint, Self::AssignedScalar)>, + ) -> Result { + const WINDOW_SIZE: usize = 3; + match self.mul_batch_1d_horizontal(ctx, pairs.clone(), WINDOW_SIZE) { + Err(_) => { + if self.assign_aux(ctx, WINDOW_SIZE, pairs.len()).is_err() { + let aux_generator = Value::known(C::Curve::random(OsRng).into()); + self.assign_aux_generator(ctx, aux_generator)?; + self.assign_aux(ctx, WINDOW_SIZE, pairs.len())?; + } + self.mul_batch_1d_horizontal(ctx, pairs, WINDOW_SIZE) + } + result => result, + } + } + + fn normalize( + &self, + ctx: &mut Self::Context, + point: &Self::AssignedEcPoint, + ) -> Result { + self.normalize(ctx, point) + } + + fn assert_equal( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedEcPoint, + b: &Self::AssignedEcPoint, + ) -> Result<(), Error> { + let mut eq = true; + [(a.x(), b.x()), (a.y(), b.y())].map(|(lhs, rhs)| { + lhs.integer().zip(rhs.integer()).map(|(lhs, rhs)| { + eq &= lhs.value() == rhs.value(); + }); + }); + self.assert_equal(ctx, a, b) + .and(eq.then_some(()).ok_or(Error::Synthesis)) + } + } +} diff --git a/src/loader/halo2/test.rs b/src/loader/halo2/test.rs new file mode 100644 index 00000000..08551fe0 --- /dev/null +++ b/src/loader/halo2/test.rs @@ -0,0 +1,66 @@ +use crate::{ + util::{arithmetic::CurveAffine, Itertools}, + Protocol, +}; +use halo2_proofs::circuit::Value; + +pub struct Snark { + pub protocol: Protocol, + pub instances: Vec>, + pub proof: Vec, +} + +impl Snark { + pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { + assert_eq!( + protocol.num_instance, + instances + .iter() + .map(|instances| instances.len()) + .collect_vec() + ); + Snark { + protocol, + instances, + proof, + } + } +} + +pub struct SnarkWitness { + pub protocol: Protocol, + pub instances: Vec>>, + pub proof: Value>, +} + +impl From> for SnarkWitness { + fn from(snark: Snark) -> Self { + Self { + protocol: snark.protocol, + instances: snark + .instances + .into_iter() + .map(|instances| instances.into_iter().map(Value::known).collect_vec()) + .collect(), + proof: Value::known(snark.proof), + } + } +} + +impl SnarkWitness { + pub fn without_witnesses(&self) -> Self { + SnarkWitness { + protocol: self.protocol.clone(), + instances: self + .instances + .iter() + .map(|instances| vec![Value::unknown(); instances.len()]) + .collect(), + proof: Value::unknown(), + } + } + + pub fn proof(&self) -> Value<&[u8]> { + self.proof.as_ref().map(Vec::as_slice) + } +} diff --git a/src/loader/halo2/transcript.rs b/src/loader/halo2/transcript.rs deleted file mode 100644 index 5bfe57ba..00000000 --- a/src/loader/halo2/transcript.rs +++ /dev/null @@ -1,363 +0,0 @@ -use crate::{ - loader::{ - halo2::loader::{EcPoint, Halo2Loader, Scalar, Value}, - native::NativeLoader, - }, - util::{Curve, GroupEncoding, PrimeField, Transcript, TranscriptRead}, - Error, -}; -use halo2_curves::{Coordinates, CurveAffine}; -use halo2_proofs::circuit; -use halo2_wrong_ecc::integer::rns::{Common, Integer, Rns}; -use halo2_wrong_transcript::{NativeRepresentation, PointRepresentation, TranscriptChip}; -use poseidon::{Poseidon, Spec}; -use std::{ - io::{self, Read, Write}, - marker::PhantomData, - rc::Rc, -}; - -pub struct PoseidonTranscript< - C: CurveAffine, - L, - S, - B, - E: PointRepresentation, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, -> { - loader: L, - stream: S, - buf: B, - rns: Rc>, - _marker: PhantomData<(C, E)>, -} - -impl< - 'a, - 'b, - C: CurveAffine, - R: Read, - E: PointRepresentation, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > - PoseidonTranscript< - C, - Rc>, - circuit::Value, - TranscriptChip, - E, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - pub fn new( - loader: &Rc>, - stream: circuit::Value, - ) -> Self { - let transcript_chip = TranscriptChip::new( - &mut loader.ctx_mut(), - &Spec::new(R_F, R_P), - loader.ecc_chip().clone(), - ) - .unwrap(); - Self { - loader: loader.clone(), - stream, - buf: transcript_chip, - rns: Rc::new(Rns::::construct()), - _marker: PhantomData, - } - } -} - -impl< - 'a, - 'b, - C: CurveAffine, - R: Read, - E: PointRepresentation, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > Transcript>> - for PoseidonTranscript< - C, - Rc>, - circuit::Value, - TranscriptChip, - E, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn squeeze_challenge(&mut self) -> Scalar<'a, 'b, C, LIMBS, BITS> { - let assigned = self.buf.squeeze(&mut self.loader.ctx_mut()).unwrap(); - self.loader.scalar(Value::Assigned(assigned)) - } - - fn common_scalar(&mut self, scalar: &Scalar<'a, 'b, C, LIMBS, BITS>) -> Result<(), Error> { - self.buf.write_scalar(&scalar.assigned()); - Ok(()) - } - - fn common_ec_point(&mut self, ec_point: &EcPoint<'a, 'b, C, LIMBS, BITS>) -> Result<(), Error> { - self.buf - .write_point(&mut self.loader.ctx_mut(), &ec_point.assigned()) - .unwrap(); - Ok(()) - } -} - -impl< - 'a, - 'b, - C: CurveAffine, - R: Read, - E: PointRepresentation, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptRead>> - for PoseidonTranscript< - C, - Rc>, - circuit::Value, - TranscriptChip, - E, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn read_scalar(&mut self) -> Result, Error> { - let scalar = self.stream.as_mut().and_then(|stream| { - let mut data = ::Repr::default(); - if stream.read_exact(data.as_mut()).is_err() { - return circuit::Value::unknown(); - } - Option::::from(C::Scalar::from_repr(data)) - .map(circuit::Value::known) - .unwrap_or_else(circuit::Value::unknown) - }); - let scalar = self.loader.assign_scalar(scalar); - self.common_scalar(&scalar)?; - Ok(scalar) - } - - fn read_ec_point(&mut self) -> Result, Error> { - let ec_point = self.stream.as_mut().and_then(|stream| { - let mut compressed = C::Repr::default(); - if stream.read_exact(compressed.as_mut()).is_err() { - return circuit::Value::unknown(); - } - Option::::from(C::from_bytes(&compressed)) - .map(circuit::Value::known) - .unwrap_or_else(circuit::Value::unknown) - }); - let ec_point = self.loader.assign_ec_point(ec_point); - self.common_ec_point(&ec_point)?; - Ok(ec_point) - } -} - -impl< - C: CurveAffine, - S, - E: PointRepresentation, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > - PoseidonTranscript< - C, - NativeLoader, - S, - Poseidon, - E, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - pub fn new(stream: S) -> Self { - Self { - loader: NativeLoader, - stream, - buf: Poseidon::new(R_F, R_P), - rns: Rc::new(Rns::::construct()), - _marker: PhantomData, - } - } -} - -impl< - C: CurveAffine, - S, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > Transcript - for PoseidonTranscript< - C, - NativeLoader, - S, - Poseidon, - NativeRepresentation, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn squeeze_challenge(&mut self) -> C::Scalar { - self.buf.squeeze() - } - - fn common_scalar(&mut self, scalar: &C::Scalar) -> Result<(), Error> { - self.buf.update(&[*scalar]); - Ok(()) - } - - fn common_ec_point(&mut self, ec_point: &C::CurveExt) -> Result<(), Error> { - let coords: Coordinates = - Option::from(ec_point.to_affine().coordinates()).ok_or_else(|| { - Error::Transcript( - io::ErrorKind::Other, - "Cannot write points at infinity to the transcript".to_string(), - ) - })?; - let x = Integer::from_fe(*coords.x(), self.rns.clone()); - let y = Integer::from_fe(*coords.y(), self.rns.clone()); - self.buf.update(&[x.native(), y.native()]); - Ok(()) - } -} - -impl< - C: CurveAffine, - R: Read, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptRead - for PoseidonTranscript< - C, - NativeLoader, - R, - Poseidon, - NativeRepresentation, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn read_scalar(&mut self) -> Result { - let mut data = ::Repr::default(); - self.stream - .read_exact(data.as_mut()) - .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; - let scalar = C::Scalar::from_repr_vartime(data).ok_or_else(|| { - Error::Transcript( - io::ErrorKind::Other, - "Invalid scalar encoding in proof".to_string(), - ) - })?; - self.common_scalar(&scalar)?; - Ok(scalar) - } - - fn read_ec_point(&mut self) -> Result { - let mut data = C::Repr::default(); - self.stream - .read_exact(data.as_mut()) - .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; - let ec_point = Option::::from( - ::from_bytes(&data).map(|ec_point| ec_point.to_curve()), - ) - .ok_or_else(|| { - Error::Transcript( - io::ErrorKind::Other, - "Invalid elliptic curve point encoding in proof".to_string(), - ) - })?; - self.common_ec_point(&ec_point)?; - Ok(ec_point) - } -} - -impl< - C: CurveAffine, - W: Write, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > - PoseidonTranscript< - C, - NativeLoader, - W, - Poseidon, - NativeRepresentation, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - pub fn stream_mut(&mut self) -> &mut W { - &mut self.stream - } - - pub fn finalize(self) -> W { - self.stream - } -} diff --git a/src/loader/native.rs b/src/loader/native.rs index d0ee9fb5..6bf9f7c4 100644 --- a/src/loader/native.rs +++ b/src/loader/native.rs @@ -1,4 +1,85 @@ -mod accumulation; -mod loader; +use crate::{ + loader::{EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, + util::arithmetic::{Curve, CurveAffine, FieldOps, PrimeField}, + Error, +}; +use lazy_static::lazy_static; +use std::fmt::Debug; -pub use loader::NativeLoader; +lazy_static! { + pub static ref LOADER: NativeLoader = NativeLoader; +} + +#[derive(Clone, Debug)] +pub struct NativeLoader; + +impl LoadedEcPoint for C { + type Loader = NativeLoader; + + fn loader(&self) -> &NativeLoader { + &LOADER + } + + fn multi_scalar_multiplication(pairs: impl IntoIterator) -> Self { + pairs + .into_iter() + .map(|(scalar, base)| base * scalar) + .reduce(|acc, value| acc + value) + .unwrap() + .to_affine() + } +} + +impl FieldOps for F { + fn invert(&self) -> Option { + self.invert().into() + } +} + +impl LoadedScalar for F { + type Loader = NativeLoader; + + fn loader(&self) -> &NativeLoader { + &LOADER + } +} + +impl EcPointLoader for NativeLoader { + type LoadedEcPoint = C; + + fn ec_point_load_const(&self, value: &C) -> Self::LoadedEcPoint { + *value + } + + fn ec_point_assert_eq( + &self, + annotation: &str, + lhs: &Self::LoadedEcPoint, + rhs: &Self::LoadedEcPoint, + ) -> Result<(), Error> { + lhs.eq(rhs) + .then_some(()) + .ok_or_else(|| Error::AssertionFailure(annotation.to_string())) + } +} + +impl ScalarLoader for NativeLoader { + type LoadedScalar = F; + + fn load_const(&self, value: &F) -> Self::LoadedScalar { + *value + } + + fn assert_eq( + &self, + annotation: &str, + lhs: &Self::LoadedScalar, + rhs: &Self::LoadedScalar, + ) -> Result<(), Error> { + lhs.eq(rhs) + .then_some(()) + .ok_or_else(|| Error::AssertionFailure(annotation.to_string())) + } +} + +impl Loader for NativeLoader {} diff --git a/src/loader/native/accumulation.rs b/src/loader/native/accumulation.rs deleted file mode 100644 index d0b67dc9..00000000 --- a/src/loader/native/accumulation.rs +++ /dev/null @@ -1,111 +0,0 @@ -use crate::{ - loader::native::NativeLoader, - protocol::Protocol, - scheme::kzg::{AccumulationStrategy, Accumulator, SameCurveAccumulation, MSM}, - util::{fe_from_limbs, Curve, Group, Itertools, PrimeCurveAffine, Transcript}, - Error, -}; -use halo2_curves::{ - pairing::{MillerLoopResult, MultiMillerLoop}, - CurveAffine, CurveExt, -}; - -impl - SameCurveAccumulation -{ - pub fn finalize(self, g1: C) -> (C, C) { - self.accumulator.unwrap().evaluate(g1) - } -} - -impl - SameCurveAccumulation -{ - pub fn decide>( - self, - g1: M::G1Affine, - g2: M::G2Affine, - s_g2: M::G2Affine, - ) -> bool { - let (lhs, rhs) = self.finalize(g1.to_curve()); - - let g2 = M::G2Prepared::from(g2); - let minus_s_g2 = M::G2Prepared::from(-s_g2); - - let terms = [(&lhs.into(), &g2), (&rhs.into(), &minus_s_g2)]; - M::multi_miller_loop(&terms) - .final_exponentiation() - .is_identity() - .into() - } -} - -impl AccumulationStrategy - for SameCurveAccumulation -where - C: CurveExt, - T: Transcript, -{ - type Output = P; - - fn extract_accumulator( - &self, - protocol: &Protocol, - _: &NativeLoader, - transcript: &mut T, - statements: &[Vec], - ) -> Option> { - let accumulator_indices = protocol.accumulator_indices.as_ref()?; - - let challenges = transcript.squeeze_n_challenges(accumulator_indices.len()); - let accumulators = accumulator_indices - .iter() - .map(|indices| { - assert_eq!(indices.len(), 4 * LIMBS); - let [lhs_x, lhs_y, rhs_x, rhs_y]: [_; 4] = indices - .chunks(4) - .into_iter() - .map(|indices| { - fe_from_limbs::<_, _, LIMBS, BITS>( - indices - .iter() - .map(|index| statements[index.0][index.1]) - .collect_vec() - .try_into() - .unwrap(), - ) - }) - .collect_vec() - .try_into() - .unwrap(); - let lhs = ::from_xy(lhs_x, lhs_y) - .unwrap() - .to_curve(); - let rhs = ::from_xy(rhs_x, rhs_y) - .unwrap() - .to_curve(); - Accumulator::new(MSM::base(lhs), MSM::base(rhs)) - }) - .collect_vec(); - - Some(Accumulator::random_linear_combine( - challenges.into_iter().zip(accumulators), - )) - } - - fn process( - &mut self, - _: &NativeLoader, - transcript: &mut T, - proof: P, - accumulator: Accumulator, - ) -> Result { - self.accumulator = Some(match self.accumulator.take() { - Some(curr_accumulator) => { - accumulator + curr_accumulator * &transcript.squeeze_challenge() - } - None => accumulator, - }); - Ok(proof) - } -} diff --git a/src/loader/native/loader.rs b/src/loader/native/loader.rs deleted file mode 100644 index 12bb475e..00000000 --- a/src/loader/native/loader.rs +++ /dev/null @@ -1,61 +0,0 @@ -use crate::{ - loader::{EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, - util::{Curve, FieldOps, PrimeField}, -}; -use lazy_static::lazy_static; -use std::fmt::Debug; - -lazy_static! { - static ref LOADER: NativeLoader = NativeLoader; -} - -impl LoadedEcPoint for C { - type Loader = NativeLoader; - - fn loader(&self) -> &NativeLoader { - &LOADER - } - - fn multi_scalar_multiplication(pairs: impl IntoIterator) -> Self { - pairs - .into_iter() - .map(|(scalar, base)| base * scalar) - .reduce(|acc, value| acc + value) - .unwrap() - } -} - -impl FieldOps for F { - fn invert(&self) -> Option { - self.invert().into() - } -} - -impl LoadedScalar for F { - type Loader = NativeLoader; - - fn loader(&self) -> &NativeLoader { - &LOADER - } -} - -#[derive(Clone, Debug)] -pub struct NativeLoader; - -impl EcPointLoader for NativeLoader { - type LoadedEcPoint = C; - - fn ec_point_load_const(&self, value: &C) -> Self::LoadedEcPoint { - *value - } -} - -impl ScalarLoader for NativeLoader { - type LoadedScalar = F; - - fn load_const(&self, value: &F) -> Self::LoadedScalar { - *value - } -} - -impl Loader for NativeLoader {} diff --git a/src/pcs.rs b/src/pcs.rs new file mode 100644 index 00000000..65804895 --- /dev/null +++ b/src/pcs.rs @@ -0,0 +1,138 @@ +use crate::{ + loader::{native::NativeLoader, Loader}, + util::{ + arithmetic::{CurveAffine, PrimeField}, + msm::Msm, + transcript::{TranscriptRead, TranscriptWrite}, + }, + Error, +}; +use rand::Rng; +use std::fmt::Debug; + +pub mod kzg; + +pub trait PolynomialCommitmentScheme: Clone + Debug +where + C: CurveAffine, + L: Loader, +{ + type Accumulator: Clone + Debug; +} + +#[derive(Clone, Debug)] +pub struct Query { + pub poly: usize, + pub shift: F, + pub eval: T, +} + +impl Query { + pub fn with_evaluation(self, eval: T) -> Query { + Query { + poly: self.poly, + shift: self.shift, + eval, + } + } +} + +pub trait MultiOpenScheme: PolynomialCommitmentScheme +where + C: CurveAffine, + L: Loader, +{ + type SuccinctVerifyingKey: Clone + Debug; + type Proof: Clone + Debug; + + fn read_proof( + svk: &Self::SuccinctVerifyingKey, + queries: &[Query], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead; + + fn succinct_verify( + svk: &Self::SuccinctVerifyingKey, + commitments: &[Msm], + point: &L::LoadedScalar, + queries: &[Query], + proof: &Self::Proof, + ) -> Result; +} + +pub trait Decider: PolynomialCommitmentScheme +where + C: CurveAffine, + L: Loader, +{ + type DecidingKey: Clone + Debug; + type Output: Clone + Debug; + + fn decide(dk: &Self::DecidingKey, accumulator: Self::Accumulator) -> Self::Output; + + fn decide_all(dk: &Self::DecidingKey, accumulators: Vec) -> Self::Output; +} + +pub trait AccumulationScheme: Clone + Debug +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme, +{ + type VerifyingKey: Clone + Debug; + type Proof: Clone + Debug; + + fn read_proof( + vk: &Self::VerifyingKey, + instances: &[PCS::Accumulator], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead; + + fn verify( + vk: &Self::VerifyingKey, + instances: &[PCS::Accumulator], + proof: &Self::Proof, + ) -> Result; +} + +pub trait AccumulationSchemeProver: AccumulationScheme +where + C: CurveAffine, + PCS: PolynomialCommitmentScheme, +{ + type ProvingKey: Clone + Debug; + + fn create_proof( + pk: &Self::ProvingKey, + instances: &[PCS::Accumulator], + transcript: &mut T, + rng: R, + ) -> Result + where + T: TranscriptWrite, + R: Rng; +} + +pub trait AccumulatorEncoding: Clone + Debug +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme, +{ + fn from_repr(repr: Vec) -> Result; +} + +impl AccumulatorEncoding for () +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme, +{ + fn from_repr(_: Vec) -> Result { + unimplemented!() + } +} diff --git a/src/pcs/kzg.rs b/src/pcs/kzg.rs new file mode 100644 index 00000000..9f10bd44 --- /dev/null +++ b/src/pcs/kzg.rs @@ -0,0 +1,45 @@ +use crate::{ + loader::Loader, + pcs::PolynomialCommitmentScheme, + util::arithmetic::{CurveAffine, MultiMillerLoop}, +}; +use std::{fmt::Debug, marker::PhantomData}; + +mod accumulation; +mod accumulator; +mod decider; +mod multiopen; + +pub use accumulation::{KzgAs, KzgAsProvingKey, KzgAsVerifyingKey}; +pub use accumulator::{KzgAccumulator, LimbsEncoding}; +pub use decider::KzgDecidingKey; +pub use multiopen::{Bdfg21, Bdfg21Proof, Gwc19, Gwc19Proof}; + +#[derive(Clone, Debug)] +pub struct Kzg(PhantomData<(M, MOS)>); + +impl PolynomialCommitmentScheme for Kzg +where + M: MultiMillerLoop, + L: Loader, + MOS: Clone + Debug, +{ + type Accumulator = KzgAccumulator; +} + +#[derive(Clone, Copy, Debug)] +pub struct KzgSuccinctVerifyingKey { + pub g: C, +} + +impl KzgSuccinctVerifyingKey { + pub fn new(g: C) -> Self { + Self { g } + } +} + +impl From for KzgSuccinctVerifyingKey { + fn from(g: C) -> KzgSuccinctVerifyingKey { + KzgSuccinctVerifyingKey::new(g) + } +} diff --git a/src/pcs/kzg/accumulation.rs b/src/pcs/kzg/accumulation.rs new file mode 100644 index 00000000..cd13fd00 --- /dev/null +++ b/src/pcs/kzg/accumulation.rs @@ -0,0 +1,196 @@ +use crate::{ + loader::{native::NativeLoader, LoadedScalar, Loader}, + pcs::{ + kzg::KzgAccumulator, AccumulationScheme, AccumulationSchemeProver, + PolynomialCommitmentScheme, + }, + util::{ + arithmetic::{Curve, CurveAffine, Field}, + msm::Msm, + transcript::{TranscriptRead, TranscriptWrite}, + }, + Error, +}; +use rand::Rng; +use std::marker::PhantomData; + +#[derive(Clone, Debug)] +pub struct KzgAs(PhantomData); + +impl AccumulationScheme for KzgAs +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme>, +{ + type VerifyingKey = KzgAsVerifyingKey; + type Proof = KzgAsProof; + + fn read_proof( + vk: &Self::VerifyingKey, + instances: &[PCS::Accumulator], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + { + KzgAsProof::read(vk, instances, transcript) + } + + fn verify( + _: &Self::VerifyingKey, + instances: &[PCS::Accumulator], + proof: &Self::Proof, + ) -> Result { + let (lhs, rhs) = instances + .iter() + .cloned() + .map(|accumulator| (accumulator.lhs, accumulator.rhs)) + .chain(proof.blind.clone()) + .unzip::<_, _, Vec<_>, Vec<_>>(); + + let powers_of_r = proof.r.powers(lhs.len()); + let [lhs, rhs] = [lhs, rhs].map(|msms| { + msms.into_iter() + .zip(powers_of_r.iter()) + .map(|(msm, r)| Msm::::base(msm) * r) + .sum::>() + .evaluate(None) + }); + + Ok(KzgAccumulator::new(lhs, rhs)) + } +} + +#[derive(Clone, Copy, Debug, Default)] +pub struct KzgAsProvingKey(pub Option<(C, C)>); + +impl KzgAsProvingKey { + pub fn new(g: Option<(C, C)>) -> Self { + Self(g) + } + + pub fn zk(&self) -> bool { + self.0.is_some() + } + + pub fn vk(&self) -> KzgAsVerifyingKey { + KzgAsVerifyingKey(self.zk()) + } +} + +#[derive(Clone, Copy, Debug, Default)] +pub struct KzgAsVerifyingKey(bool); + +impl KzgAsVerifyingKey { + pub fn zk(&self) -> bool { + self.0 + } +} + +#[derive(Clone, Debug)] +pub struct KzgAsProof +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme>, +{ + blind: Option<(L::LoadedEcPoint, L::LoadedEcPoint)>, + r: L::LoadedScalar, + _marker: PhantomData, +} + +impl KzgAsProof +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme>, +{ + fn read( + vk: &KzgAsVerifyingKey, + instances: &[PCS::Accumulator], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + { + assert!(!instances.is_empty()); + + for accumulator in instances { + transcript.common_ec_point(&accumulator.lhs)?; + transcript.common_ec_point(&accumulator.rhs)?; + } + + let blind = vk + .zk() + .then(|| Ok((transcript.read_ec_point()?, transcript.read_ec_point()?))) + .transpose()?; + + let r = transcript.squeeze_challenge(); + + Ok(Self { + blind, + r, + _marker: PhantomData, + }) + } +} + +impl AccumulationSchemeProver for KzgAs +where + C: CurveAffine, + PCS: PolynomialCommitmentScheme>, +{ + type ProvingKey = KzgAsProvingKey; + + fn create_proof( + pk: &Self::ProvingKey, + instances: &[PCS::Accumulator], + transcript: &mut T, + rng: R, + ) -> Result + where + T: TranscriptWrite, + R: Rng, + { + assert!(!instances.is_empty()); + + for accumulator in instances { + transcript.common_ec_point(&accumulator.lhs)?; + transcript.common_ec_point(&accumulator.rhs)?; + } + + let blind = pk + .zk() + .then(|| { + let s = C::Scalar::random(rng); + let (g, s_g) = pk.0.unwrap(); + let lhs = (s_g * s).to_affine(); + let rhs = (g * s).to_affine(); + transcript.write_ec_point(lhs)?; + transcript.write_ec_point(rhs)?; + Ok((lhs, rhs)) + }) + .transpose()?; + + let r = transcript.squeeze_challenge(); + + let (lhs, rhs) = instances + .iter() + .cloned() + .map(|accumulator| (accumulator.lhs, accumulator.rhs)) + .chain(blind) + .unzip::<_, _, Vec<_>, Vec<_>>(); + + let powers_of_r = r.powers(lhs.len()); + let [lhs, rhs] = [lhs, rhs].map(|msms| { + msms.into_iter() + .zip(powers_of_r.iter()) + .map(|(msm, power_of_r)| Msm::::base(msm) * power_of_r) + .sum::>() + .evaluate(None) + }); + + Ok(KzgAccumulator::new(lhs, rhs)) + } +} diff --git a/src/pcs/kzg/accumulator.rs b/src/pcs/kzg/accumulator.rs new file mode 100644 index 00000000..17c7bf83 --- /dev/null +++ b/src/pcs/kzg/accumulator.rs @@ -0,0 +1,208 @@ +use crate::{loader::Loader, util::arithmetic::CurveAffine}; +use std::fmt::Debug; + +#[derive(Clone, Debug)] +pub struct KzgAccumulator +where + C: CurveAffine, + L: Loader, +{ + pub lhs: L::LoadedEcPoint, + pub rhs: L::LoadedEcPoint, +} + +impl KzgAccumulator +where + C: CurveAffine, + L: Loader, +{ + pub fn new(lhs: L::LoadedEcPoint, rhs: L::LoadedEcPoint) -> Self { + Self { lhs, rhs } + } +} + +/// `AccumulatorEncoding` that encodes `Accumulator` into limbs. +/// +/// Since in circuit everything are in scalar field, but `Accumulator` might contain base field elements, so we split them into limbs. +/// The const generic `LIMBS` and `BITS` respectively represents how many limbs +/// a base field element are split into and how many bits each limbs could have. +#[derive(Clone, Debug)] +pub struct LimbsEncoding; + +mod native { + use crate::{ + loader::native::NativeLoader, + pcs::{ + kzg::{KzgAccumulator, LimbsEncoding}, + AccumulatorEncoding, PolynomialCommitmentScheme, + }, + util::{ + arithmetic::{fe_from_limbs, CurveAffine}, + Itertools, + }, + Error, + }; + + impl AccumulatorEncoding + for LimbsEncoding + where + C: CurveAffine, + PCS: PolynomialCommitmentScheme< + C, + NativeLoader, + Accumulator = KzgAccumulator, + >, + { + fn from_repr(limbs: Vec) -> Result { + assert_eq!(limbs.len(), 4 * LIMBS); + + let [lhs_x, lhs_y, rhs_x, rhs_y]: [_; 4] = limbs + .chunks(LIMBS) + .into_iter() + .map(|limbs| fe_from_limbs::<_, _, LIMBS, BITS>(limbs.try_into().unwrap())) + .collect_vec() + .try_into() + .unwrap(); + let accumulator = KzgAccumulator::new( + C::from_xy(lhs_x, lhs_y).unwrap(), + C::from_xy(rhs_x, rhs_y).unwrap(), + ); + + Ok(accumulator) + } + } +} + +#[cfg(feature = "loader_evm")] +mod evm { + use crate::{ + loader::evm::{EvmLoader, Scalar}, + pcs::{ + kzg::{KzgAccumulator, LimbsEncoding}, + AccumulatorEncoding, PolynomialCommitmentScheme, + }, + util::{ + arithmetic::{CurveAffine, PrimeField}, + Itertools, + }, + Error, + }; + use std::rc::Rc; + + impl AccumulatorEncoding, PCS> + for LimbsEncoding + where + C: CurveAffine, + C::Scalar: PrimeField, + PCS: PolynomialCommitmentScheme< + C, + Rc, + Accumulator = KzgAccumulator>, + >, + { + fn from_repr(limbs: Vec) -> Result { + assert_eq!(limbs.len(), 4 * LIMBS); + + let loader = limbs[0].loader(); + + let [lhs_x, lhs_y, rhs_x, rhs_y]: [[_; LIMBS]; 4] = limbs + .chunks(LIMBS) + .into_iter() + .map(|limbs| limbs.to_vec().try_into().unwrap()) + .collect_vec() + .try_into() + .unwrap(); + let accumulator = KzgAccumulator::new( + loader.ec_point_from_limbs::(lhs_x, lhs_y), + loader.ec_point_from_limbs::(rhs_x, rhs_y), + ); + + Ok(accumulator) + } + } +} + +#[cfg(feature = "loader_halo2")] +mod halo2 { + use crate::{ + loader::halo2::{Context, EccInstructions, Halo2Loader, Scalar, Valuetools}, + pcs::{ + kzg::{KzgAccumulator, LimbsEncoding}, + AccumulatorEncoding, PolynomialCommitmentScheme, + }, + util::{ + arithmetic::{fe_from_limbs, CurveAffine}, + Itertools, + }, + Error, + }; + use halo2_proofs::circuit::Value; + use halo2_wrong_ecc::{maingate::AssignedValue, AssignedPoint}; + use std::{iter, rc::Rc}; + + fn ec_point_from_assigned_limbs( + limbs: &[AssignedValue], + ) -> Value { + assert_eq!(limbs.len(), 2 * LIMBS); + + let [x, y] = [&limbs[..LIMBS], &limbs[LIMBS..]].map(|limbs| { + limbs + .iter() + .map(|assigned| assigned.value()) + .fold_zipped(Vec::new(), |mut acc, limb| { + acc.push(*limb); + acc + }) + .map(|limbs| fe_from_limbs::<_, _, LIMBS, BITS>(limbs.try_into().unwrap())) + }); + + x.zip(y).map(|(x, y)| C::from_xy(x, y).unwrap()) + } + + impl<'a, C, PCS, EccChip, const LIMBS: usize, const BITS: usize> + AccumulatorEncoding>, PCS> for LimbsEncoding + where + C: CurveAffine, + PCS: PolynomialCommitmentScheme< + C, + Rc>, + Accumulator = KzgAccumulator>>, + >, + EccChip: EccInstructions< + 'a, + C, + AssignedEcPoint = AssignedPoint<::Base, C::Scalar, LIMBS, BITS>, + AssignedScalar = AssignedValue, + >, + { + fn from_repr(limbs: Vec>) -> Result { + assert_eq!(limbs.len(), 4 * LIMBS); + + let loader = limbs[0].loader(); + + let assigned_limbs = limbs.iter().map(|limb| limb.assigned()).collect_vec(); + let [lhs, rhs] = [&assigned_limbs[..2 * LIMBS], &assigned_limbs[2 * LIMBS..]].map( + |assigned_limbs| { + let ec_point = ec_point_from_assigned_limbs::<_, LIMBS, BITS>(assigned_limbs); + loader.assign_ec_point(ec_point) + }, + ); + + for (src, dst) in assigned_limbs.iter().zip( + iter::empty() + .chain(lhs.assigned().x().limbs()) + .chain(lhs.assigned().y().limbs()) + .chain(rhs.assigned().x().limbs()) + .chain(rhs.assigned().y().limbs()), + ) { + loader + .ctx_mut() + .constrain_equal(src.cell(), dst.as_ref().cell()) + .unwrap(); + } + let accumulator = KzgAccumulator::new(lhs, rhs); + + Ok(accumulator) + } + } +} diff --git a/src/pcs/kzg/decider.rs b/src/pcs/kzg/decider.rs new file mode 100644 index 00000000..b6957883 --- /dev/null +++ b/src/pcs/kzg/decider.rs @@ -0,0 +1,162 @@ +use crate::util::arithmetic::MultiMillerLoop; +use std::marker::PhantomData; + +#[derive(Debug, Clone, Copy)] +pub struct KzgDecidingKey { + pub g2: M::G2Affine, + pub s_g2: M::G2Affine, + _marker: PhantomData, +} + +impl KzgDecidingKey { + pub fn new(g2: M::G2Affine, s_g2: M::G2Affine) -> Self { + Self { + g2, + s_g2, + _marker: PhantomData, + } + } +} + +impl From<(M::G2Affine, M::G2Affine)> for KzgDecidingKey { + fn from((g2, s_g2): (M::G2Affine, M::G2Affine)) -> KzgDecidingKey { + KzgDecidingKey::new(g2, s_g2) + } +} + +mod native { + use crate::{ + loader::native::NativeLoader, + pcs::{ + kzg::{Kzg, KzgAccumulator, KzgDecidingKey}, + Decider, + }, + util::arithmetic::{Group, MillerLoopResult, MultiMillerLoop}, + }; + use std::fmt::Debug; + + impl Decider for Kzg + where + M: MultiMillerLoop, + MOS: Clone + Debug, + { + type DecidingKey = KzgDecidingKey; + type Output = bool; + + fn decide( + dk: &Self::DecidingKey, + KzgAccumulator { lhs, rhs }: KzgAccumulator, + ) -> bool { + let terms = [(&lhs, &dk.g2.into()), (&rhs, &(-dk.s_g2).into())]; + M::multi_miller_loop(&terms) + .final_exponentiation() + .is_identity() + .into() + } + + fn decide_all( + dk: &Self::DecidingKey, + accumulators: Vec>, + ) -> bool { + !accumulators + .into_iter() + .any(|accumulator| !Self::decide(dk, accumulator)) + } + } +} + +#[cfg(feature = "loader_evm")] +mod evm { + use crate::{ + loader::{ + evm::{loader::Value, EvmLoader}, + LoadedScalar, + }, + pcs::{ + kzg::{Kzg, KzgAccumulator, KzgDecidingKey}, + Decider, + }, + util::{ + arithmetic::{CurveAffine, MultiMillerLoop, PrimeField}, + msm::Msm, + }, + }; + use ethereum_types::U256; + use std::{fmt::Debug, rc::Rc}; + + impl Decider> for Kzg + where + M: MultiMillerLoop, + M::Scalar: PrimeField, + MOS: Clone + Debug, + { + type DecidingKey = KzgDecidingKey; + type Output = (); + + fn decide( + dk: &Self::DecidingKey, + KzgAccumulator { lhs, rhs }: KzgAccumulator>, + ) { + let loader = lhs.loader(); + let [g2, minus_s_g2] = [dk.g2, -dk.s_g2].map(|ec_point| { + let coordinates = ec_point.coordinates().unwrap(); + let x = coordinates.x().to_repr(); + let y = coordinates.y().to_repr(); + ( + U256::from_little_endian(&x.as_ref()[32..]), + U256::from_little_endian(&x.as_ref()[..32]), + U256::from_little_endian(&y.as_ref()[32..]), + U256::from_little_endian(&y.as_ref()[..32]), + ) + }); + loader.pairing(&lhs, g2, &rhs, minus_s_g2); + } + + fn decide_all( + dk: &Self::DecidingKey, + mut accumulators: Vec>>, + ) { + assert!(!accumulators.is_empty()); + + let accumulator = if accumulators.len() == 1 { + accumulators.pop().unwrap() + } else { + let loader = accumulators[0].lhs.loader(); + let (lhs, rhs) = accumulators + .iter() + .map(|KzgAccumulator { lhs, rhs }| { + let [lhs, rhs] = [&lhs, &rhs].map(|ec_point| loader.dup_ec_point(ec_point)); + (lhs, rhs) + }) + .unzip::<_, _, Vec<_>, Vec<_>>(); + + let hash_ptr = loader.keccak256(lhs[0].ptr(), lhs.len() * 0x80); + let challenge_ptr = loader.allocate(0x20); + loader + .code_mut() + .push(loader.scalar_modulus()) + .push(hash_ptr) + .mload() + .r#mod() + .push(challenge_ptr) + .mstore(); + let challenge = loader.scalar(Value::Memory(challenge_ptr)); + + let powers_of_challenge = LoadedScalar::::powers(&challenge, lhs.len()); + let [lhs, rhs] = [lhs, rhs].map(|msms| { + msms.into_iter() + .zip(powers_of_challenge.iter()) + .map(|(msm, power_of_challenge)| { + Msm::>::base(msm) * power_of_challenge + }) + .sum::>() + .evaluate(None) + }); + + KzgAccumulator::new(lhs, rhs) + }; + + Self::decide(dk, accumulator) + } + } +} diff --git a/src/pcs/kzg/multiopen.rs b/src/pcs/kzg/multiopen.rs new file mode 100644 index 00000000..d3e50e62 --- /dev/null +++ b/src/pcs/kzg/multiopen.rs @@ -0,0 +1,5 @@ +mod bdfg21; +mod gwc19; + +pub use bdfg21::{Bdfg21, Bdfg21Proof}; +pub use gwc19::{Gwc19, Gwc19Proof}; diff --git a/src/pcs/kzg/multiopen/bdfg21.rs b/src/pcs/kzg/multiopen/bdfg21.rs new file mode 100644 index 00000000..287700d7 --- /dev/null +++ b/src/pcs/kzg/multiopen/bdfg21.rs @@ -0,0 +1,381 @@ +use crate::{ + cost::{Cost, CostEstimation}, + loader::{LoadedScalar, Loader, ScalarLoader}, + pcs::{ + kzg::{Kzg, KzgAccumulator, KzgSuccinctVerifyingKey}, + MultiOpenScheme, Query, + }, + util::{ + arithmetic::{ilog2, CurveAffine, FieldExt, Fraction, MultiMillerLoop}, + msm::Msm, + transcript::TranscriptRead, + Itertools, + }, + Error, +}; +use std::{ + collections::{BTreeMap, BTreeSet}, + marker::PhantomData, +}; + +#[derive(Clone, Debug)] +pub struct Bdfg21; + +impl MultiOpenScheme for Kzg +where + M: MultiMillerLoop, + L: Loader, +{ + type SuccinctVerifyingKey = KzgSuccinctVerifyingKey; + type Proof = Bdfg21Proof; + + fn read_proof( + _: &KzgSuccinctVerifyingKey, + _: &[Query], + transcript: &mut T, + ) -> Result, Error> + where + T: TranscriptRead, + { + Bdfg21Proof::read(transcript) + } + + fn succinct_verify( + svk: &KzgSuccinctVerifyingKey, + commitments: &[Msm], + z: &L::LoadedScalar, + queries: &[Query], + proof: &Bdfg21Proof, + ) -> Result { + let f = { + let sets = query_sets(queries); + let coeffs = query_set_coeffs(&sets, z, &proof.z_prime); + + let powers_of_mu = proof + .mu + .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); + let msms = sets + .iter() + .zip(coeffs.iter()) + .map(|(set, coeff)| set.msm(coeff, commitments, &powers_of_mu)); + + msms.zip(proof.gamma.powers(sets.len()).into_iter()) + .map(|(msm, power_of_gamma)| msm * &power_of_gamma) + .sum::>() + - Msm::base(proof.w.clone()) * &coeffs[0].z_s + }; + + let rhs = Msm::base(proof.w_prime.clone()); + let lhs = f + rhs.clone() * &proof.z_prime; + + Ok(KzgAccumulator::new( + lhs.evaluate(Some(svk.g)), + rhs.evaluate(Some(svk.g)), + )) + } +} + +#[derive(Clone, Debug)] +pub struct Bdfg21Proof +where + C: CurveAffine, + L: Loader, +{ + mu: L::LoadedScalar, + gamma: L::LoadedScalar, + w: L::LoadedEcPoint, + z_prime: L::LoadedScalar, + w_prime: L::LoadedEcPoint, +} + +impl Bdfg21Proof +where + C: CurveAffine, + L: Loader, +{ + fn read>(transcript: &mut T) -> Result { + let mu = transcript.squeeze_challenge(); + let gamma = transcript.squeeze_challenge(); + let w = transcript.read_ec_point()?; + let z_prime = transcript.squeeze_challenge(); + let w_prime = transcript.read_ec_point()?; + Ok(Bdfg21Proof { + mu, + gamma, + w, + z_prime, + w_prime, + }) + } +} + +fn query_sets(queries: &[Query]) -> Vec> { + let poly_shifts = queries.iter().fold( + Vec::<(usize, Vec, Vec<&T>)>::new(), + |mut poly_shifts, query| { + if let Some(pos) = poly_shifts + .iter() + .position(|(poly, _, _)| *poly == query.poly) + { + let (_, shifts, evals) = &mut poly_shifts[pos]; + if !shifts.contains(&query.shift) { + shifts.push(query.shift); + evals.push(&query.eval); + } + } else { + poly_shifts.push((query.poly, vec![query.shift], vec![&query.eval])); + } + poly_shifts + }, + ); + + poly_shifts.into_iter().fold( + Vec::>::new(), + |mut sets, (poly, shifts, evals)| { + if let Some(pos) = sets.iter().position(|set| { + BTreeSet::from_iter(set.shifts.iter()) == BTreeSet::from_iter(shifts.iter()) + }) { + let set = &mut sets[pos]; + if !set.polys.contains(&poly) { + set.polys.push(poly); + set.evals.push( + set.shifts + .iter() + .map(|lhs| { + let idx = shifts.iter().position(|rhs| lhs == rhs).unwrap(); + evals[idx].clone() + }) + .collect(), + ); + } + } else { + let set = QuerySet { + shifts, + polys: vec![poly], + evals: vec![evals.into_iter().cloned().collect()], + }; + sets.push(set); + } + sets + }, + ) +} + +fn query_set_coeffs>( + sets: &[QuerySet], + z: &T, + z_prime: &T, +) -> Vec> { + let loader = z.loader(); + + let superset = sets + .iter() + .flat_map(|set| set.shifts.clone()) + .sorted() + .dedup(); + + let size = 2.max( + ilog2((sets.iter().map(|set| set.shifts.len()).max().unwrap() - 1).next_power_of_two()) + 1, + ); + let powers_of_z = z.powers(size); + let z_prime_minus_z_shift_i = BTreeMap::from_iter(superset.map(|shift| { + ( + shift, + z_prime.clone() - z.clone() * loader.load_const(&shift), + ) + })); + + let mut z_s_1 = None; + let mut coeffs = sets + .iter() + .map(|set| { + let coeff = QuerySetCoeff::new( + &set.shifts, + &powers_of_z, + z_prime, + &z_prime_minus_z_shift_i, + &z_s_1, + ); + if z_s_1.is_none() { + z_s_1 = Some(coeff.z_s.clone()); + }; + coeff + }) + .collect_vec(); + + T::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); + T::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); + coeffs.iter_mut().for_each(QuerySetCoeff::evaluate); + + coeffs +} + +#[derive(Clone, Debug)] +struct QuerySet { + shifts: Vec, + polys: Vec, + evals: Vec>, +} + +impl> QuerySet { + fn msm>( + &self, + coeff: &QuerySetCoeff, + commitments: &[Msm], + powers_of_mu: &[T], + ) -> Msm { + self.polys + .iter() + .zip(self.evals.iter()) + .zip(powers_of_mu.iter()) + .map(|((poly, evals), power_of_mu)| { + let loader = power_of_mu.loader(); + let commitment = coeff + .commitment_coeff + .as_ref() + .map(|commitment_coeff| { + commitments[*poly].clone() * commitment_coeff.evaluated() + }) + .unwrap_or_else(|| commitments[*poly].clone()); + let r_eval = loader.sum_products( + &coeff + .eval_coeffs + .iter() + .zip(evals.iter()) + .map(|(coeff, eval)| (coeff.evaluated(), eval)) + .collect_vec(), + ) * coeff.r_eval_coeff.as_ref().unwrap().evaluated(); + (commitment - Msm::constant(r_eval)) * power_of_mu + }) + .sum() + } +} + +#[derive(Clone, Debug)] +struct QuerySetCoeff { + z_s: T, + eval_coeffs: Vec>, + commitment_coeff: Option>, + r_eval_coeff: Option>, + _marker: PhantomData, +} + +impl QuerySetCoeff +where + F: FieldExt, + T: LoadedScalar, +{ + fn new( + shifts: &[F], + powers_of_z: &[T], + z_prime: &T, + z_prime_minus_z_shift_i: &BTreeMap, + z_s_1: &Option, + ) -> Self { + let loader = z_prime.loader(); + + let normalized_ell_primes = shifts + .iter() + .enumerate() + .map(|(j, shift_j)| { + shifts + .iter() + .enumerate() + .filter(|&(i, _)| i != j) + .map(|(_, shift_i)| (*shift_j - shift_i)) + .reduce(|acc, value| acc * value) + .unwrap_or_else(|| F::one()) + }) + .collect_vec(); + + let z = &powers_of_z[1].clone(); + let z_pow_k_minus_one = { + let k_minus_one = shifts.len() - 1; + powers_of_z + .iter() + .enumerate() + .skip(1) + .filter_map(|(i, power_of_z)| { + (k_minus_one & (1 << i) == 1).then(|| power_of_z.clone()) + }) + .reduce(|acc, value| acc * value) + .unwrap_or_else(|| loader.load_one()) + }; + + let barycentric_weights = shifts + .iter() + .zip(normalized_ell_primes.iter()) + .map(|(shift, normalized_ell_prime)| { + loader.sum_products_with_coeff(&[ + (*normalized_ell_prime, &z_pow_k_minus_one, z_prime), + (-(*normalized_ell_prime * shift), &z_pow_k_minus_one, z), + ]) + }) + .map(Fraction::one_over) + .collect_vec(); + + let z_s = loader.product( + &shifts + .iter() + .map(|shift| z_prime_minus_z_shift_i.get(shift).unwrap()) + .collect_vec(), + ); + let z_s_1_over_z_s = z_s_1.clone().map(|z_s_1| Fraction::new(z_s_1, z_s.clone())); + + Self { + z_s, + eval_coeffs: barycentric_weights, + commitment_coeff: z_s_1_over_z_s, + r_eval_coeff: None, + _marker: PhantomData, + } + } + + fn denoms(&mut self) -> impl IntoIterator { + if self.eval_coeffs.first().unwrap().denom().is_some() { + return self + .eval_coeffs + .iter_mut() + .chain(self.commitment_coeff.as_mut()) + .filter_map(Fraction::denom_mut) + .collect_vec(); + } + + if self.r_eval_coeff.is_none() { + let loader = self.z_s.loader(); + self.eval_coeffs + .iter_mut() + .chain(self.commitment_coeff.as_mut()) + .for_each(Fraction::evaluate); + let barycentric_weights_sum = loader.sum( + &self + .eval_coeffs + .iter() + .map(Fraction::evaluated) + .collect_vec(), + ); + self.r_eval_coeff = Some(match self.commitment_coeff.clone() { + Some(coeff) => Fraction::new(coeff.evaluated().clone(), barycentric_weights_sum), + None => Fraction::one_over(barycentric_weights_sum), + }); + return vec![self.r_eval_coeff.as_mut().unwrap().denom_mut().unwrap()]; + } + + unreachable!() + } + + fn evaluate(&mut self) { + self.r_eval_coeff.as_mut().unwrap().evaluate(); + } +} + +impl CostEstimation for Kzg +where + M: MultiMillerLoop, +{ + type Input = Vec>; + + fn estimate_cost(_: &Vec>) -> Cost { + Cost::new(0, 2, 0, 2) + } +} diff --git a/src/pcs/kzg/multiopen/gwc19.rs b/src/pcs/kzg/multiopen/gwc19.rs new file mode 100644 index 00000000..121fce8a --- /dev/null +++ b/src/pcs/kzg/multiopen/gwc19.rs @@ -0,0 +1,167 @@ +use crate::{ + cost::{Cost, CostEstimation}, + loader::{LoadedScalar, Loader}, + pcs::{ + kzg::{Kzg, KzgAccumulator, KzgSuccinctVerifyingKey}, + MultiOpenScheme, Query, + }, + util::{ + arithmetic::{CurveAffine, MultiMillerLoop, PrimeField}, + msm::Msm, + transcript::TranscriptRead, + Itertools, + }, + Error, +}; + +#[derive(Clone, Debug)] +pub struct Gwc19; + +impl MultiOpenScheme for Kzg +where + M: MultiMillerLoop, + L: Loader, +{ + type SuccinctVerifyingKey = KzgSuccinctVerifyingKey; + type Proof = Gwc19Proof; + + fn read_proof( + _: &Self::SuccinctVerifyingKey, + queries: &[Query], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + { + Gwc19Proof::read(queries, transcript) + } + + fn succinct_verify( + svk: &Self::SuccinctVerifyingKey, + commitments: &[Msm], + z: &L::LoadedScalar, + queries: &[Query], + proof: &Self::Proof, + ) -> Result { + let sets = query_sets(queries); + let powers_of_u = &proof.u.powers(sets.len()); + let f = { + let powers_of_v = proof + .v + .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); + sets.iter() + .map(|set| set.msm(commitments, &powers_of_v)) + .zip(powers_of_u.iter()) + .map(|(msm, power_of_u)| msm * power_of_u) + .sum::>() + }; + let z_omegas = sets + .iter() + .map(|set| z.clone() * &z.loader().load_const(&set.shift)); + + let rhs = proof + .ws + .iter() + .zip(powers_of_u.iter()) + .map(|(w, power_of_u)| Msm::base(w.clone()) * power_of_u) + .collect_vec(); + let lhs = f + rhs + .iter() + .zip(z_omegas) + .map(|(uw, z_omega)| uw.clone() * &z_omega) + .sum(); + + Ok(KzgAccumulator::new( + lhs.evaluate(Some(svk.g)), + rhs.into_iter().sum::>().evaluate(Some(svk.g)), + )) + } +} + +#[derive(Clone, Debug)] +pub struct Gwc19Proof +where + C: CurveAffine, + L: Loader, +{ + v: L::LoadedScalar, + ws: Vec, + u: L::LoadedScalar, +} + +impl Gwc19Proof +where + C: CurveAffine, + L: Loader, +{ + fn read(queries: &[Query], transcript: &mut T) -> Result + where + T: TranscriptRead, + { + let v = transcript.squeeze_challenge(); + let ws = transcript.read_n_ec_points(query_sets(queries).len())?; + let u = transcript.squeeze_challenge(); + Ok(Gwc19Proof { v, ws, u }) + } +} + +struct QuerySet { + shift: F, + polys: Vec, + evals: Vec, +} + +impl QuerySet +where + F: PrimeField, + T: Clone, +{ + fn msm>( + &self, + commitments: &[Msm], + powers_of_v: &[L::LoadedScalar], + ) -> Msm { + self.polys + .iter() + .zip(self.evals.iter()) + .map(|(poly, eval)| { + let commitment = commitments[*poly].clone(); + commitment - Msm::constant(eval.clone()) + }) + .zip(powers_of_v.iter()) + .map(|(msm, power_of_v)| msm * power_of_v) + .sum() + } +} + +fn query_sets(queries: &[Query]) -> Vec> +where + F: PrimeField, + T: Clone + PartialEq, +{ + queries.iter().fold(Vec::new(), |mut sets, query| { + if let Some(pos) = sets.iter().position(|set| set.shift == query.shift) { + sets[pos].polys.push(query.poly); + sets[pos].evals.push(query.eval.clone()); + } else { + sets.push(QuerySet { + shift: query.shift, + polys: vec![query.poly], + evals: vec![query.eval.clone()], + }); + } + sets + }) +} + +impl CostEstimation for Kzg +where + M: MultiMillerLoop, +{ + type Input = Vec>; + + fn estimate_cost(queries: &Vec>) -> Cost { + let num_w = query_sets(queries).len(); + Cost::new(0, num_w, 0, num_w) + } +} diff --git a/src/protocol.rs b/src/protocol.rs deleted file mode 100644 index af591b78..00000000 --- a/src/protocol.rs +++ /dev/null @@ -1,54 +0,0 @@ -use crate::util::{Curve, Domain, Expression, Group, Itertools, Query}; - -#[cfg(feature = "halo2")] -pub mod halo2; - -#[derive(Clone, Debug)] -pub struct Protocol { - pub zk: bool, - pub domain: Domain, - pub preprocessed: Vec, - pub num_statement: Vec, - pub num_auxiliary: Vec, - pub num_challenge: Vec, - pub evaluations: Vec, - pub queries: Vec, - pub relations: Vec>, - pub transcript_initial_state: C::Scalar, - pub accumulator_indices: Option>>, -} - -impl Protocol { - pub fn vanishing_poly(&self) -> usize { - self.preprocessed.len() - + self.num_statement.len() - + self.num_auxiliary.iter().sum::() - } -} - -pub struct Snark { - pub protocol: Protocol, - pub statements: Vec::Scalar>>, - pub proof: Vec, -} - -impl Snark { - pub fn new( - protocol: Protocol, - statements: Vec::Scalar>>, - proof: Vec, - ) -> Self { - assert_eq!( - protocol.num_statement, - statements - .iter() - .map(|statements| statements.len()) - .collect_vec() - ); - Snark { - protocol, - statements, - proof, - } - } -} diff --git a/src/protocol/halo2/test.rs b/src/protocol/halo2/test.rs deleted file mode 100644 index cfbcefef..00000000 --- a/src/protocol/halo2/test.rs +++ /dev/null @@ -1,176 +0,0 @@ -use crate::{ - protocol::halo2::{compile, Config}, - scheme::kzg::{Cost, CostEstimation, PlonkAccumulationScheme}, - util::{CommonPolynomial, Expression, Query}, -}; -use halo2_curves::bn256::{Bn256, Fr, G1}; -use halo2_proofs::{ - arithmetic::FieldExt, - dev::MockProver, - plonk::{create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey}, - poly::{ - commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier}, - kzg::commitment::KZGCommitmentScheme, - Rotation, VerificationStrategy, - }, - transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer}, -}; -use rand_chacha::{ - rand_core::{RngCore, SeedableRng}, - ChaCha20Rng, -}; -use std::assert_matches::assert_matches; - -mod circuit; -mod kzg; - -pub use circuit::{ - maingate::{ - MainGateWithPlookup, MainGateWithPlookupConfig, MainGateWithRange, MainGateWithRangeConfig, - }, - standard::StandardPlonk, -}; - -pub fn create_proof_checked<'a, S, C, P, V, VS, TW, TR, EC, R, const ZK: bool>( - params: &'a S::ParamsProver, - pk: &ProvingKey, - circuits: &[C], - instances: &[&[&[S::Scalar]]], - mut rng: R, -) -> Vec -where - S: CommitmentScheme, - S::ParamsVerifier: 'a, - C: Circuit, - P: Prover<'a, S>, - V: Verifier<'a, S>, - VS: VerificationStrategy<'a, S, V, Output = VS>, - TW: TranscriptWriterBuffer, S::Curve, EC>, - TR: TranscriptReadBuffer<&'static [u8], S::Curve, EC>, - EC: EncodedChallenge, - R: RngCore, -{ - for (circuit, instances) in circuits.iter().zip(instances.iter()) { - MockProver::run::<_, ZK>( - params.k(), - circuit, - instances.iter().map(|instance| instance.to_vec()).collect(), - ) - .unwrap() - .assert_satisfied(); - } - - let proof = { - let mut transcript = TW::init(Vec::new()); - create_proof::( - params, - pk, - circuits, - instances, - &mut rng, - &mut transcript, - ) - .unwrap(); - transcript.finalize() - }; - - let accept = { - let params = params.verifier_params(); - let strategy = VS::new(params); - let mut transcript = TR::init(Box::leak(Box::new(proof.clone()))); - verify_proof::<_, _, _, _, _, ZK>(params, pk.get_vk(), strategy, instances, &mut transcript) - .unwrap() - .finalize() - }; - assert!(accept); - - proof -} - -#[test] -fn test_compile_standard_plonk() { - let circuit = StandardPlonk::rand(ChaCha20Rng::from_seed(Default::default())); - - let params = kzg::read_or_create_srs::(9); - let vk = keygen_vk::, _, false>(¶ms, &circuit).unwrap(); - let pk = keygen_pk::, _, false>(¶ms, vk, &circuit).unwrap(); - - let protocol = compile::( - pk.get_vk(), - Config { - zk: false, - query_instance: false, - num_instance: vec![1], - num_proof: 1, - accumulator_indices: None, - }, - ); - - let [q_a, q_b, q_c, q_ab, constant, sigma_a, sigma_b, sigma_c, instance, a, b, c, z] = - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12].map(|poly| Query::new(poly, Rotation::cur())); - let z_w = Query::new(12, Rotation::next()); - let t = Query::new(13, Rotation::cur()); - - assert_eq!(protocol.preprocessed.len(), 8); - assert_eq!(protocol.num_statement, vec![1]); - assert_eq!(protocol.num_auxiliary, vec![3, 0, 1]); - assert_eq!(protocol.num_challenge, vec![1, 2, 0]); - assert_eq!( - protocol.evaluations, - vec![a, b, c, q_a, q_b, q_c, q_ab, constant, sigma_a, sigma_b, sigma_c, z, z_w] - ); - assert_eq!( - protocol.queries, - vec![a, b, c, z, z_w, q_a, q_b, q_c, q_ab, constant, sigma_a, sigma_b, sigma_c, t] - ); - assert_eq!( - format!("{:?}", protocol.relations), - format!("{:?}", { - let [q_a, q_b, q_c, q_ab, constant, sigma_a, sigma_b, sigma_c, instance, a, b, c, z, z_w, beta, gamma, l_0, identity, one, k_1, k_2] = - &[ - Expression::Polynomial(q_a), - Expression::Polynomial(q_b), - Expression::Polynomial(q_c), - Expression::Polynomial(q_ab), - Expression::Polynomial(constant), - Expression::Polynomial(sigma_a), - Expression::Polynomial(sigma_b), - Expression::Polynomial(sigma_c), - Expression::Polynomial(instance), - Expression::Polynomial(a), - Expression::Polynomial(b), - Expression::Polynomial(c), - Expression::Polynomial(z), - Expression::Polynomial(z_w), - Expression::Challenge(1), // beta - Expression::Challenge(2), // gamma - Expression::CommonPolynomial(CommonPolynomial::Lagrange(0)), // l_0 - Expression::CommonPolynomial(CommonPolynomial::Identity), // identity - Expression::Constant(Fr::one()), // one - Expression::Constant(Fr::DELTA), // k_1 - Expression::Constant(Fr::DELTA * Fr::DELTA), // k_2 - ]; - - vec![ - q_a * a + q_b * b + q_c * c + q_ab * a * b + constant + instance, - l_0 * (one - z), - z_w * ((a + beta * sigma_a + gamma) - * (b + beta * sigma_b + gamma) - * (c + beta * sigma_c + gamma)) - - z * ((a + beta * one * identity + gamma) - * (b + beta * k_1 * identity + gamma) - * (c + beta * k_2 * identity + gamma)), - ] - }) - ); - - assert_matches!( - PlonkAccumulationScheme::estimate_cost(&protocol), - Cost { - num_commitment: 9, - num_evaluation: 13, - num_msm: 20, - .. - } - ); -} diff --git a/src/protocol/halo2/test/circuit/maingate.rs b/src/protocol/halo2/test/circuit/maingate.rs deleted file mode 100644 index b03a3ff4..00000000 --- a/src/protocol/halo2/test/circuit/maingate.rs +++ /dev/null @@ -1,385 +0,0 @@ -use crate::{protocol::halo2::test::circuit::plookup::PlookupConfig, util::Itertools}; -use halo2_proofs::{ - arithmetic::FieldExt, - circuit::{floor_planner::V1, Chip, Layouter, Value}, - plonk::{Any, Circuit, Column, ConstraintSystem, Error, Fixed}, - poly::Rotation, -}; -use halo2_wrong_ecc::{ - maingate::{ - decompose, AssignedValue, MainGate, MainGateConfig, MainGateInstructions, RangeChip, - RangeConfig, RangeInstructions, RegionCtx, Term, - }, - EccConfig, -}; -use rand::RngCore; -use std::{collections::BTreeMap, iter}; - -#[derive(Clone)] -pub struct MainGateWithRangeConfig { - main_gate_config: MainGateConfig, - range_config: RangeConfig, -} - -impl MainGateWithRangeConfig { - pub fn configure( - meta: &mut ConstraintSystem, - composition_bits: Vec, - overflow_bits: Vec, - ) -> Self { - let main_gate_config = MainGate::::configure(meta); - let range_config = - RangeChip::::configure(meta, &main_gate_config, composition_bits, overflow_bits); - MainGateWithRangeConfig { - main_gate_config, - range_config, - } - } - - pub fn ecc_config(&self) -> EccConfig { - EccConfig::new(self.range_config.clone(), self.main_gate_config.clone()) - } - - pub fn load_table(&self, layouter: &mut impl Layouter) -> Result<(), Error> { - let range_chip = RangeChip::::new(self.range_config.clone()); - range_chip.load_table(layouter)?; - Ok(()) - } -} - -#[derive(Clone, Default)] -pub struct MainGateWithRange(Vec); - -impl MainGateWithRange { - pub fn new(inner: Vec) -> Self { - Self(inner) - } - - pub fn rand(mut rng: R) -> Self { - Self::new(vec![F::from(rng.next_u32() as u64)]) - } - - pub fn instances(&self) -> Vec> { - vec![self.0.clone()] - } -} - -impl Circuit for MainGateWithRange { - type Config = MainGateWithRangeConfig; - type FloorPlanner = V1; - - fn without_witnesses(&self) -> Self { - Self(vec![F::zero()]) - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - MainGateWithRangeConfig::configure(meta, vec![8], vec![4, 7]) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - let main_gate = MainGate::new(config.main_gate_config); - let range_chip = RangeChip::new(config.range_config); - range_chip.load_table(&mut layouter)?; - - let a = layouter.assign_region( - || "", - |mut region| { - let mut offset = 0; - let mut ctx = RegionCtx::new(&mut region, &mut offset); - range_chip.decompose(&mut ctx, Value::known(F::from(u64::MAX)), 8, 64)?; - range_chip.decompose(&mut ctx, Value::known(F::from(u32::MAX as u64)), 8, 39)?; - let a = range_chip.assign(&mut ctx, Value::known(self.0[0]), 8, 68)?; - let b = main_gate.sub_sub_with_constant(&mut ctx, &a, &a, &a, F::from(2))?; - let cond = main_gate.assign_bit(&mut ctx, Value::known(F::one()))?; - main_gate.select(&mut ctx, &a, &b, &cond)?; - - Ok(a) - }, - )?; - - main_gate.expose_public(layouter, a, 0)?; - - Ok(()) - } -} - -#[derive(Clone, Debug)] -pub struct PlookupRangeConfig { - main_gate_config: MainGateConfig, - plookup_config: PlookupConfig, - table: [Column; 2], - q_limb: [Column; 2], - q_overflow: [Column; 2], - bits: BTreeMap, -} - -pub struct PlookupRangeChip { - n: usize, - config: PlookupRangeConfig, - main_gate: MainGate, -} - -impl PlookupRangeChip { - pub fn new(config: PlookupRangeConfig, n: usize) -> Self { - let main_gate = MainGate::new(config.main_gate_config.clone()); - Self { - n, - config, - main_gate, - } - } - - pub fn configure( - meta: &mut ConstraintSystem, - main_gate_config: MainGateConfig, - bits: impl IntoIterator, - ) -> PlookupRangeConfig { - let table = [(); 2].map(|_| meta.fixed_column()); - let q_limb = [(); 2].map(|_| meta.fixed_column()); - let q_overflow = [(); 2].map(|_| meta.fixed_column()); - let plookup_config = PlookupConfig::configure( - meta, - |meta| { - let [a, b, c, d, _] = main_gate_config.advices(); - let limbs = [a, b, c, d].map(|column| meta.query_advice(column, Rotation::cur())); - let overflow = meta.query_advice(a, Rotation::cur()); - let q_limb = q_limb.map(|column| meta.query_fixed(column, Rotation::cur())); - let q_overflow = q_overflow.map(|column| meta.query_fixed(column, Rotation::cur())); - iter::empty() - .chain(limbs.into_iter().zip(iter::repeat(q_limb))) - .chain(Some((overflow, q_overflow))) - .map(|(value, [selector, tag])| [tag, selector * value]) - .collect() - }, - table.map(Column::::from), - None, - None, - None, - None, - ); - let bits = bits - .into_iter() - .sorted() - .dedup() - .enumerate() - .map(|(tag, bit)| (bit, tag)) - .collect(); - PlookupRangeConfig { - main_gate_config, - plookup_config, - table, - q_limb, - q_overflow, - bits, - } - } - - pub fn assign_inner(&self, layouter: impl Layouter, n: usize) -> Result<(), Error> { - self.config.plookup_config.assign(layouter, n) - } -} - -impl Chip for PlookupRangeChip { - type Config = PlookupRangeConfig; - - type Loaded = (); - - fn config(&self) -> &Self::Config { - &self.config - } - - fn loaded(&self) -> &Self::Loaded { - &() - } -} - -impl RangeInstructions for PlookupRangeChip { - fn assign( - &self, - ctx: &mut RegionCtx<'_, '_, F>, - value: Value, - limb_bit: usize, - bit: usize, - ) -> Result, Error> { - let (assigned, _) = self.decompose(ctx, value, limb_bit, bit)?; - Ok(assigned) - } - - fn decompose( - &self, - ctx: &mut RegionCtx<'_, '_, F>, - value: Value, - limb_bit: usize, - bit: usize, - ) -> Result<(AssignedValue, Vec>), Error> { - let (num_limbs, overflow) = (bit / limb_bit, bit % limb_bit); - - let num_limbs = num_limbs + if overflow > 0 { 1 } else { 0 }; - let terms = value - .map(|value| decompose(value, num_limbs, limb_bit)) - .transpose_vec(num_limbs) - .into_iter() - .zip((0..num_limbs).map(|i| F::from(2).pow(&[(limb_bit * i) as u64, 0, 0, 0]))) - .map(|(limb, base)| Term::Unassigned(limb, base)) - .collect_vec(); - - self.main_gate - .decompose(ctx, &terms, F::zero(), |ctx, is_last| { - ctx.assign_fixed(|| "", self.config.q_limb[0], F::one())?; - ctx.assign_fixed( - || "", - self.config.q_limb[1], - F::from(*self.config.bits.get(&limb_bit).unwrap() as u64), - )?; - if is_last && overflow != 0 { - ctx.assign_fixed(|| "", self.config.q_overflow[0], F::one())?; - ctx.assign_fixed( - || "", - self.config.q_overflow[1], - F::from(*self.config.bits.get(&limb_bit).unwrap() as u64), - )?; - } - Ok(()) - }) - } - - fn load_table(&self, layouter: &mut impl Layouter) -> Result<(), Error> { - layouter.assign_region( - || "", - |mut region| { - let mut offset = 0; - - for (bit, tag) in self.config.bits.iter() { - let tag = F::from(*tag as u64); - let table_values: Vec = (0..1 << bit).map(|e| F::from(e)).collect(); - for value in table_values.iter() { - region.assign_fixed( - || "table tag", - self.config.table[0], - offset, - || Value::known(tag), - )?; - region.assign_fixed( - || "table value", - self.config.table[1], - offset, - || Value::known(*value), - )?; - offset += 1; - } - } - - for offset in offset..self.n { - region.assign_fixed( - || "table tag", - self.config.table[0], - offset, - || Value::known(F::zero()), - )?; - region.assign_fixed( - || "table value", - self.config.table[1], - offset, - || Value::known(F::zero()), - )?; - } - - Ok(()) - }, - )?; - - Ok(()) - } -} - -#[derive(Clone)] -pub struct MainGateWithPlookupConfig { - main_gate_config: MainGateConfig, - plookup_range_config: PlookupRangeConfig, -} - -impl MainGateWithPlookupConfig { - pub fn configure( - meta: &mut ConstraintSystem, - bits: impl IntoIterator, - ) -> Self { - let main_gate_config = MainGate::configure(meta); - let plookup_range_config = - PlookupRangeChip::configure(meta, main_gate_config.clone(), bits); - - assert_eq!(meta.degree::(), 3); - - MainGateWithPlookupConfig { - main_gate_config, - plookup_range_config, - } - } -} - -#[derive(Clone, Default)] -pub struct MainGateWithPlookup { - n: usize, - inner: Vec, -} - -impl MainGateWithPlookup { - pub fn new(k: u32, inner: Vec) -> Self { - Self { n: 1 << k, inner } - } - - pub fn instances(&self) -> Vec> { - vec![self.inner.clone()] - } -} - -impl Circuit for MainGateWithPlookup { - type Config = MainGateWithPlookupConfig; - type FloorPlanner = V1; - - fn without_witnesses(&self) -> Self { - Self { - n: self.n, - inner: vec![F::zero()], - } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - MainGateWithPlookupConfig::configure(meta, [1, 7, 8]) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - let main_gate = MainGate::::new(config.main_gate_config.clone()); - let range_chip = PlookupRangeChip::new(config.plookup_range_config, self.n); - - range_chip.load_table(&mut layouter)?; - range_chip.assign_inner(layouter.namespace(|| ""), self.n)?; - - let a = layouter.assign_region( - || "", - |mut region| { - let mut offset = 0; - let mut ctx = RegionCtx::new(&mut region, &mut offset); - range_chip.decompose(&mut ctx, Value::known(F::from(u64::MAX)), 8, 64)?; - range_chip.decompose(&mut ctx, Value::known(F::from(u32::MAX as u64)), 8, 39)?; - let a = range_chip.assign(&mut ctx, Value::known(self.inner[0]), 8, 68)?; - let b = main_gate.sub_sub_with_constant(&mut ctx, &a, &a, &a, F::from(2))?; - let cond = main_gate.assign_bit(&mut ctx, Value::known(F::one()))?; - main_gate.select(&mut ctx, &a, &b, &cond)?; - - Ok(a) - }, - )?; - - main_gate.expose_public(layouter, a, 0)?; - - Ok(()) - } -} diff --git a/src/protocol/halo2/test/circuit/plookup.rs b/src/protocol/halo2/test/circuit/plookup.rs deleted file mode 100644 index 4e05e076..00000000 --- a/src/protocol/halo2/test/circuit/plookup.rs +++ /dev/null @@ -1,947 +0,0 @@ -use crate::util::{BatchInvert, EitherOrBoth, Field, Itertools}; -use halo2_proofs::{ - arithmetic::FieldExt, - circuit::{Layouter, Value}, - plonk::{ - Advice, Any, Challenge, Column, ConstraintSystem, Error, Expression, FirstPhase, - SecondPhase, Selector, ThirdPhase, VirtualCells, - }, - poly::Rotation, -}; -use std::{collections::BTreeMap, convert::TryFrom, iter, ops::Mul}; - -fn query( - meta: &mut ConstraintSystem, - query_fn: impl FnOnce(&mut VirtualCells<'_, F>) -> T, -) -> T { - let mut tmp = None; - meta.create_gate("", |meta| { - tmp = Some(query_fn(meta)); - Some(Expression::Constant(F::zero())) - }); - tmp.unwrap() -} - -fn first_fit_packing(cap: usize, weights: Vec) -> Vec> { - let mut bins = Vec::<(usize, Vec)>::new(); - - weights.into_iter().enumerate().for_each(|(idx, weight)| { - for (remaining, indices) in bins.iter_mut() { - if *remaining >= weight { - *remaining -= weight; - indices.push(idx); - return; - } - } - bins.push((cap - weight, vec![idx])); - }); - - bins.into_iter().map(|(_, indices)| indices).collect() -} - -fn max_advice_phase(expression: &Expression) -> u8 { - expression.evaluate( - &|_| 0, - &|_| 0, - &|_| 0, - &|query| query.phase(), - &|_| 0, - &|_| 0, - &|a| a, - &|a, b| a.max(b), - &|a, b| a.max(b), - &|a, _| a, - ) -} - -fn min_challenge_phase(expression: &Expression) -> Option { - expression.evaluate( - &|_| None, - &|_| None, - &|_| None, - &|_| None, - &|_| None, - &|challenge| Some(challenge.phase()), - &|a| a, - &|a, b| match (a, b) { - (Some(a), Some(b)) => Some(a.min(b)), - (Some(phase), None) | (None, Some(phase)) => Some(phase), - (None, None) => None, - }, - &|a, b| match (a, b) { - (Some(a), Some(b)) => Some(a.min(b)), - (Some(phase), None) | (None, Some(phase)) => Some(phase), - (None, None) => None, - }, - &|a, _| a, - ) -} - -fn advice_column_in(meta: &mut ConstraintSystem, phase: u8) -> Column { - match phase { - 0 => meta.advice_column_in(FirstPhase), - 1 => meta.advice_column_in(SecondPhase), - 2 => meta.advice_column_in(ThirdPhase), - _ => unreachable!(), - } -} - -fn challenge_usable_after(meta: &mut ConstraintSystem, phase: u8) -> Challenge { - match phase { - 0 => meta.challenge_usable_after(FirstPhase), - 1 => meta.challenge_usable_after(SecondPhase), - 2 => meta.challenge_usable_after(ThirdPhase), - _ => unreachable!(), - } -} - -#[derive(Clone, Debug)] -pub struct ShuffleConfig { - l_0: Selector, - zs: Vec>, - gamma: Option, - lhs: Vec>, - rhs: Vec>, -} - -impl ShuffleConfig { - pub fn configure( - meta: &mut ConstraintSystem, - lhs: impl FnOnce(&mut VirtualCells<'_, F>) -> Vec>, - rhs: impl FnOnce(&mut VirtualCells<'_, F>) -> Vec>, - l_0: Option, - ) -> Self { - let (lhs, rhs, gamma) = { - let (lhs, rhs) = query(meta, |meta| { - let (lhs, rhs) = (lhs(meta), rhs(meta)); - assert_eq!(lhs.len(), rhs.len()); - (lhs, rhs) - }); - let phase = iter::empty() - .chain(lhs.iter()) - .chain(rhs.iter()) - .map(max_advice_phase) - .max() - .unwrap(); - - let gamma = challenge_usable_after(meta, phase); - - (lhs, rhs, gamma) - }; - let lhs_with_gamma = |meta: &mut VirtualCells<'_, F>| { - let gamma = meta.query_challenge(gamma); - lhs.into_iter().zip(iter::repeat(gamma)).collect() - }; - let rhs_with_gamma = |meta: &mut VirtualCells<'_, F>| { - let gamma = meta.query_challenge(gamma); - rhs.into_iter().zip(iter::repeat(gamma)).collect() - }; - let mut config = Self::configure_with_gamma( - meta, - lhs_with_gamma, - rhs_with_gamma, - |_| None, - |_| None, - l_0, - ); - config.gamma = Some(gamma); - config - } - - pub fn configure_with_gamma( - meta: &mut ConstraintSystem, - lhs_with_gamma: impl FnOnce(&mut VirtualCells<'_, F>) -> Vec<(Expression, Expression)>, - rhs_with_gamma: impl FnOnce(&mut VirtualCells<'_, F>) -> Vec<(Expression, Expression)>, - lhs_coeff: impl FnOnce(&mut VirtualCells<'_, F>) -> Option>, - rhs_coeff: impl FnOnce(&mut VirtualCells<'_, F>) -> Option>, - l_0: Option, - ) -> Self { - if ZK { - todo!() - } - - let (lhs_with_gamma, rhs_with_gamma, lhs_coeff, rhs_coeff) = query(meta, |meta| { - let lhs_with_gamma = lhs_with_gamma(meta); - let rhs_with_gamma = rhs_with_gamma(meta); - let lhs_coeff = lhs_coeff(meta); - let rhs_coeff = rhs_coeff(meta); - assert_eq!(lhs_with_gamma.len(), rhs_with_gamma.len()); - - (lhs_with_gamma, rhs_with_gamma, lhs_coeff, rhs_coeff) - }); - - let gamma_phase = iter::empty() - .chain(lhs_with_gamma.iter()) - .chain(rhs_with_gamma.iter()) - .map(|(value, _)| max_advice_phase(value)) - .max() - .unwrap(); - let z_phase = gamma_phase + 1; - assert!(!lhs_with_gamma - .iter() - .any(|(_, gamma)| gamma.degree() != 0 - || min_challenge_phase(gamma).unwrap() < gamma_phase)); - assert!(!rhs_with_gamma - .iter() - .any(|(_, gamma)| gamma.degree() != 0 - || min_challenge_phase(gamma).unwrap() < gamma_phase)); - - let [lhs_bins, rhs_bins] = [&lhs_with_gamma, &rhs_with_gamma].map(|value_with_gamma| { - first_fit_packing( - meta.degree::() - 1, - value_with_gamma - .iter() - .map(|(value, _)| value.degree()) - .collect(), - ) - }); - let num_z = lhs_bins.len().max(rhs_bins.len()); - - let l_0 = l_0.unwrap_or_else(|| meta.selector()); - let zs = iter::repeat_with(|| advice_column_in(meta, z_phase)) - .take(num_z) - .collect_vec(); - - let collect_contribution = |value_with_gamma: Vec<(Expression, Expression)>, - coeff: Option>, - bins: &[Vec]| { - let mut contribution = bins - .iter() - .map(|bin| { - bin.iter() - .map(|idx| value_with_gamma[*idx].clone()) - .map(|(value, gamma)| value + gamma) - .reduce(|acc, expr| acc * expr) - .unwrap() - }) - .collect_vec(); - - if let Some(coeff) = coeff { - contribution[0] = coeff * contribution[0].clone(); - } - - contribution - }; - let lhs = collect_contribution(lhs_with_gamma, lhs_coeff, &lhs_bins); - let rhs = collect_contribution(rhs_with_gamma, rhs_coeff, &rhs_bins); - - meta.create_gate("Shuffle", |meta| { - let l_0 = meta.query_selector(l_0); - let zs = iter::empty() - .chain(zs.iter().cloned().zip(iter::repeat(Rotation::cur()))) - .chain(Some((zs[0], Rotation::next()))) - .map(|(z, at)| meta.query_advice(z, at)) - .collect_vec(); - - let one = Expression::Constant(F::one()); - let z_0 = zs[0].clone(); - - iter::once(l_0 * (one - z_0)).chain( - lhs.clone() - .into_iter() - .zip_longest(rhs.clone()) - .zip(zs.clone().into_iter().zip(zs.into_iter().skip(1))) - .map(|(pair, (z_i, z_j))| match pair { - EitherOrBoth::Left(lhs) => z_i * lhs - z_j, - EitherOrBoth::Right(rhs) => z_i - z_j * rhs, - EitherOrBoth::Both(lhs, rhs) => z_i * lhs - z_j * rhs, - }), - ) - }); - - ShuffleConfig { - l_0, - zs, - gamma: None, - lhs, - rhs, - } - } - - pub fn assign(&self, mut layouter: impl Layouter, n: usize) -> Result<(), Error> { - if ZK { - todo!() - } - - let lhs = self - .lhs - .iter() - .map(|expression| layouter.evaluate_committed(expression)) - .fold(Value::known(Vec::new()), |acc, evaluated| { - acc.zip(evaluated).map(|(mut acc, evaluated)| { - acc.extend(evaluated); - acc - }) - }); - let rhs = self - .rhs - .iter() - .map(|expression| layouter.evaluate_committed(expression)) - .fold(Value::known(Vec::new()), |acc, evaluated| { - acc.zip(evaluated).map(|(mut acc, evaluated)| { - acc.extend(evaluated); - acc - }) - }); - - let z = lhs - .zip(rhs) - .map(|(lhs, mut rhs)| { - rhs.iter_mut().batch_invert(); - - let products = lhs - .into_iter() - .zip_longest(rhs) - .map(|pair| match pair { - EitherOrBoth::Left(value) | EitherOrBoth::Right(value) => value, - EitherOrBoth::Both(lhs, rhs) => lhs * rhs, - }) - .collect_vec(); - - let mut z = vec![F::one()]; - for i in 0..n { - for j in (i..).step_by(n).take(self.zs.len()) { - z.push(products[j] * z.last().unwrap()); - } - } - - let _last = z.pop().unwrap(); - #[cfg(feature = "sanity-check")] - assert_eq!(_last, F::one()); - - z - }) - .transpose_vec(self.zs.len() * n); - - layouter.assign_region( - || "zs", - |mut region| { - self.l_0.enable(&mut region, 0)?; - - let mut z = z.iter(); - for offset in 0..n { - for column in self.zs.iter() { - region.assign_advice(|| "", *column, offset, || *z.next().unwrap())?; - } - } - - Ok(()) - }, - ) - } -} - -fn binomial_coeffs(n: usize) -> Vec { - debug_assert!(n > 0); - - match n { - 1 => vec![1], - _ => { - let last_row = binomial_coeffs(n - 1); - iter::once(0) - .chain(last_row.iter().cloned()) - .zip(last_row.iter().cloned().chain(iter::once(0))) - .map(|(n, m)| n + m) - .collect() - } - } -} - -fn powers>(one: T, base: T) -> impl Iterator { - iter::successors(Some(one), move |power| Some(base.clone() * power.clone())) -} - -fn ordered_multiset(inputs: &[Vec], table: &[F]) -> Vec { - let mut input_counts = inputs - .iter() - .flatten() - .fold(BTreeMap::new(), |mut map, value| { - map.entry(value) - .and_modify(|count| *count += 1) - .or_insert(1); - map - }); - - let mut ordered = Vec::with_capacity((inputs.len() + 1) * inputs[0].len()); - for (count, value) in table.iter().dedup_with_count() { - let count = input_counts - .remove(value) - .map(|input_count| input_count + count) - .unwrap_or(count); - ordered.extend(iter::repeat(*value).take(count)); - } - - #[cfg(feature = "sanity-check")] - { - assert_eq!(input_counts.len(), 0); - assert_eq!(ordered.len(), ordered.capacity()); - } - - ordered.extend(iter::repeat(*ordered.last().unwrap()).take(ordered.capacity() - ordered.len())); - - ordered -} - -#[allow(dead_code)] -#[derive(Clone, Debug)] -pub struct PlookupConfig { - shuffle: ShuffleConfig, - compressed_inputs: Vec>, - compressed_table: Expression, - mixes: Vec>, - theta: Option, - beta: Challenge, - gamma: Challenge, -} - -impl PlookupConfig { - pub fn configure( - meta: &mut ConstraintSystem, - inputs: impl FnOnce(&mut VirtualCells<'_, F>) -> Vec<[Expression; W]>, - table: [Column; W], - l_0: Option, - theta: Option, - beta: Option, - gamma: Option, - ) -> Self { - if ZK { - todo!() - } - - let inputs = query(meta, inputs); - let t = inputs.len(); - let theta_phase = iter::empty() - .chain(inputs.iter().flatten()) - .map(max_advice_phase) - .chain(table.iter().map(|column| { - Column::::try_from(*column) - .map(|column| column.column_type().phase()) - .unwrap_or_default() - })) - .max() - .unwrap(); - let mixes_phase = theta_phase + 1; - - let theta = if W > 1 { - Some(match theta { - Some(theta) => { - assert!(theta.phase() >= theta_phase); - theta - } - None => challenge_usable_after(meta, theta_phase), - }) - } else { - assert!(theta.is_none()); - None - }; - let mixes = iter::repeat_with(|| advice_column_in(meta, mixes_phase)) - .take(t + 1) - .collect_vec(); - let [beta, gamma] = [beta, gamma].map(|challenge| match challenge { - Some(challenge) => { - assert!(challenge.phase() >= mixes_phase); - challenge - } - None => challenge_usable_after(meta, mixes_phase), - }); - assert_ne!(theta, Some(beta)); - assert_ne!(theta, Some(gamma)); - assert_ne!(beta, gamma); - - let (compressed_inputs, compressed_table, compressed_table_w) = query(meta, |meta| { - let [table, table_w] = [Rotation::cur(), Rotation::next()] - .map(|at| table.map(|column| meta.query_any(column, at))); - let theta = theta.map(|theta| meta.query_challenge(theta)); - - let compressed_inputs = inputs - .iter() - .map(|input| { - input - .iter() - .cloned() - .reduce(|acc, expr| acc * theta.clone().unwrap() + expr) - .unwrap() - }) - .collect_vec(); - let compressed_table = table - .iter() - .cloned() - .reduce(|acc, expr| acc * theta.clone().unwrap() + expr) - .unwrap(); - let compressed_table_w = table_w - .iter() - .cloned() - .reduce(|acc, expr| acc * theta.clone().unwrap() + expr) - .unwrap(); - - (compressed_inputs, compressed_table, compressed_table_w) - }); - let lhs_with_gamma = |meta: &mut VirtualCells<'_, F>| { - let [beta, gamma] = [beta, gamma].map(|challenge| meta.query_challenge(challenge)); - let one = Expression::Constant(F::one()); - let gamma_prime = (one + beta.clone()) * gamma.clone(); - - let values = compressed_inputs.clone().into_iter().chain(Some( - compressed_table.clone() + compressed_table_w.clone() * beta, - )); - let gammas = iter::empty() - .chain(iter::repeat(gamma).take(t)) - .chain(Some(gamma_prime)); - values.zip(gammas).collect() - }; - let rhs_with_gamma = |meta: &mut VirtualCells<'_, F>| { - let mixes = iter::empty() - .chain(mixes.iter().cloned().zip(iter::repeat(Rotation::cur()))) - .chain(Some((mixes[0], Rotation::next()))) - .map(|(column, at)| meta.query_advice(column, at)) - .collect_vec(); - let [beta, gamma] = [beta, gamma].map(|challenge| meta.query_challenge(challenge)); - let one = Expression::Constant(F::one()); - let gamma_prime = (one + beta.clone()) * gamma; - - let values = mixes - .iter() - .cloned() - .zip(mixes.iter().skip(1).cloned()) - .zip(iter::repeat(beta)) - .map(|((mix_i, mix_j), beta)| mix_i + mix_j * beta); - let gammas = iter::repeat(gamma_prime).take(t + 1); - values.zip(gammas).collect() - }; - let lhs_coeff = |meta: &mut VirtualCells<'_, F>| { - let beta = meta.query_challenge(beta); - let one = Expression::Constant(F::one()); - binomial_coeffs(t + 1) - .into_iter() - .zip(powers(one, beta)) - .map(|(coeff, power_of_beta)| Expression::Constant(F::from(coeff)) * power_of_beta) - .reduce(|acc, expr| acc + expr) - }; - let shuffle = ShuffleConfig::configure_with_gamma( - meta, - lhs_with_gamma, - rhs_with_gamma, - lhs_coeff, - |_| None, - l_0, - ); - - Self { - shuffle, - compressed_inputs, - compressed_table, - mixes, - theta, - beta, - gamma, - } - } - - pub fn assign(&self, mut layouter: impl Layouter, n: usize) -> Result<(), Error> { - if ZK { - todo!() - } - - let compressed_inputs = self - .compressed_inputs - .iter() - .map(|expression| layouter.evaluate_committed(expression)) - .fold(Value::known(Vec::new()), |acc, compressed_input| { - acc.zip(compressed_input) - .map(|(mut acc, compressed_input)| { - acc.push(compressed_input); - acc - }) - }); - let compressed_table = layouter.evaluate_committed(&self.compressed_table); - - let mix = compressed_inputs - .zip(compressed_table.as_ref()) - .map(|(compressed_inputs, compressed_table)| { - ordered_multiset(&compressed_inputs, compressed_table) - }) - .transpose_vec(self.mixes.len() * n); - - layouter.assign_region( - || "mixes", - |mut region| { - let mut mix = mix.iter(); - for offset in 0..n { - for column in self.mixes.iter() { - region.assign_advice(|| "", *column, offset, || *mix.next().unwrap())?; - } - } - - Ok(()) - }, - )?; - - self.shuffle.assign(layouter.namespace(|| "Shuffle"), n)?; - - Ok(()) - } -} - -#[cfg(test)] -mod test { - use super::{PlookupConfig, ShuffleConfig}; - use crate::util::Itertools; - use halo2_curves::{bn256::Fr, FieldExt}; - use halo2_proofs::{ - circuit::{floor_planner::V1, Layouter, Value}, - dev::{metadata::Constraint, FailureLocation, MockProver, VerifyFailure}, - plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed}, - poly::Rotation, - }; - use rand::{rngs::OsRng, RngCore}; - use std::{iter, mem}; - - fn shuffled( - mut values: [Vec; T], - mut rng: R, - ) -> [Vec; T] { - let n = values[0].len(); - let mut swap = |lhs: usize, rhs: usize| { - let tmp = mem::take(&mut values[lhs / n][lhs % n]); - values[lhs / n][lhs % n] = mem::replace(&mut values[rhs / n][rhs % n], tmp); - }; - - for row in (1..n * T).rev() { - let rand_row = (rng.next_u32() as usize) % row; - swap(row, rand_row); - } - - values - } - - #[derive(Clone)] - pub struct Shuffler { - n: usize, - lhs: Value<[Vec; T]>, - rhs: Value<[Vec; T]>, - } - - impl Shuffler { - pub fn rand(k: u32, mut rng: R) -> Self { - let n = 1 << k; - let lhs = [(); T].map(|_| { - let rng = &mut rng; - iter::repeat_with(|| F::random(&mut *rng)) - .take(n) - .collect_vec() - }); - let rhs = shuffled( - lhs.iter() - .map(|lhs| lhs.iter().map(F::square).collect()) - .collect_vec() - .try_into() - .unwrap(), - rng, - ); - Self { - n, - lhs: Value::known(lhs), - rhs: Value::known(rhs), - } - } - } - - impl Circuit for Shuffler { - type Config = ( - [Column; T], - [Column; T], - ShuffleConfig, - ); - type FloorPlanner = V1; - - fn without_witnesses(&self) -> Self { - Self { - n: self.n, - lhs: Value::unknown(), - rhs: Value::unknown(), - } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let lhs = [(); T].map(|_| meta.advice_column()); - let rhs = [(); T].map(|_| meta.advice_column()); - let shuffle = ShuffleConfig::configure( - meta, - |meta| { - lhs.map(|column| { - let lhs = meta.query_advice(column, Rotation::cur()); - lhs.clone() * lhs - }) - .to_vec() - }, - |meta| { - rhs.map(|column| meta.query_advice(column, Rotation::cur())) - .to_vec() - }, - None, - ); - - (lhs, rhs, shuffle) - } - - fn synthesize( - &self, - (lhs, rhs, shuffle): Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - layouter.assign_region( - || "", - |mut region| { - for (idx, column) in lhs.into_iter().enumerate() { - let values = self.lhs.as_ref().map(|lhs| lhs[idx].clone()); - for (offset, value) in - values.clone().transpose_vec(self.n).into_iter().enumerate() - { - region.assign_advice(|| "", column, offset, || value)?; - } - } - for (idx, column) in rhs.into_iter().enumerate() { - let values = self.rhs.as_ref().map(|rhs| rhs[idx].clone()); - for (offset, value) in - values.clone().transpose_vec(self.n).into_iter().enumerate() - { - region.assign_advice(|| "", column, offset, || value)?; - } - } - Ok(()) - }, - )?; - shuffle.assign(layouter.namespace(|| "Shuffle"), self.n)?; - - Ok(()) - } - } - - #[derive(Clone)] - pub struct Plookuper { - n: usize, - inputs: Value<[Vec<[F; W]>; T]>, - table: Vec<[F; W]>, - } - - impl Plookuper { - pub fn rand(k: u32, mut rng: R) -> Self { - let n = 1 << k; - let m = rng.next_u32() as usize % n; - let mut table = iter::repeat_with(|| [(); W].map(|_| F::random(&mut rng))) - .take(m) - .collect_vec(); - table.extend( - iter::repeat( - table - .first() - .cloned() - .unwrap_or_else(|| [(); W].map(|_| F::random(&mut rng))), - ) - .take(n - m), - ); - let inputs = [(); T].map(|_| { - iter::repeat_with(|| table[rng.next_u32() as usize % n]) - .take(n) - .collect() - }); - Self { - n, - inputs: Value::known(inputs), - table, - } - } - } - - impl Circuit - for Plookuper - { - type Config = ( - [[Column; W]; T], - [Column; W], - PlookupConfig, - ); - type FloorPlanner = V1; - - fn without_witnesses(&self) -> Self { - Self { - n: self.n, - inputs: Value::unknown(), - table: self.table.clone(), - } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let inputs = [(); T].map(|_| [(); W].map(|_| meta.advice_column())); - let table = [(); W].map(|_| meta.fixed_column()); - let plookup = PlookupConfig::configure( - meta, - |meta| { - inputs - .iter() - .map(|input| input.map(|column| meta.query_advice(column, Rotation::cur()))) - .collect() - }, - table.map(|fixed| fixed.into()), - None, - None, - None, - None, - ); - - (inputs, table, plookup) - } - - fn synthesize( - &self, - (inputs, table, plookup): Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - layouter.assign_region( - || "", - |mut region| { - for (offset, value) in self.table.iter().enumerate() { - for (column, value) in table.iter().zip(value.iter()) { - region.assign_fixed(|| "", *column, offset, || Value::known(*value))?; - } - } - Ok(()) - }, - )?; - layouter.assign_region( - || "", - |mut region| { - for (idx, columns) in inputs.iter().enumerate() { - let values = self.inputs.as_ref().map(|inputs| inputs[idx].clone()); - for (offset, value) in values.transpose_vec(self.n).into_iter().enumerate() - { - for (column, value) in columns.iter().zip(value.transpose_array()) { - region.assign_advice(|| "", *column, offset, || value)?; - } - } - } - Ok(()) - }, - )?; - plookup.assign(layouter.namespace(|| "Plookup"), self.n)?; - Ok(()) - } - } - - #[allow(dead_code)] - fn assert_constraint_not_satisfied( - result: Result<(), Vec>, - failures: Vec<(Constraint, FailureLocation)>, - ) { - match result { - Err(expected) => { - assert_eq!( - expected - .into_iter() - .map(|failure| match failure { - VerifyFailure::ConstraintNotSatisfied { - constraint, - location, - .. - } => (constraint, location), - _ => panic!("MockProver::verify has unexpected failure"), - }) - .collect_vec(), - failures - ) - } - Ok(_) => { - panic!("MockProver::verify unexpectedly succeeds") - } - } - } - - #[test] - fn test_shuffle() { - const T: usize = 9; - const ZK: bool = false; - - let k = 9; - let circuit = Shuffler::::rand(k, OsRng); - - let mut cs = ConstraintSystem::default(); - Shuffler::::configure(&mut cs); - assert_eq!(cs.degree::(), 3); - - MockProver::run::<_, ZK>(k, &circuit, Vec::new()) - .unwrap() - .assert_satisfied(); - - #[cfg(not(feature = "sanity-check"))] - { - let n = 1 << k; - let mut circuit = circuit; - circuit.lhs = mem::take(&mut circuit.lhs).map(|mut value| { - value[0][0] += Fr::one(); - value - }); - assert_constraint_not_satisfied( - MockProver::run::<_, ZK>(k, &circuit, Vec::new()) - .unwrap() - .verify(), - vec![( - ( - (2, "Shuffle").into(), - (T * 2).div_ceil(cs.degree::() - 1), - "", - ) - .into(), - FailureLocation::InRegion { - region: (0, "").into(), - offset: n - 1, - }, - )], - ); - } - } - - #[test] - fn test_plookup() { - const W: usize = 2; - const T: usize = 5; - const ZK: bool = false; - - let k = 9; - let circuit = Plookuper::::rand(k, OsRng); - - let mut cs = ConstraintSystem::default(); - Plookuper::::configure(&mut cs); - assert_eq!(cs.degree::(), 3); - - MockProver::run::<_, ZK>(k, &circuit, Vec::new()) - .unwrap() - .assert_satisfied(); - - #[cfg(not(feature = "sanity-check"))] - { - let n = 1 << k; - let mut circuit = circuit; - circuit.inputs = mem::take(&mut circuit.inputs).map(|mut inputs| { - inputs[0][0][0] += Fr::one(); - inputs - }); - assert_constraint_not_satisfied( - MockProver::run::<_, ZK>(k, &circuit, Vec::new()) - .unwrap() - .verify(), - vec![( - ( - (3, "Shuffle").into(), - (T + 1).div_ceil(cs.degree::() - 1), - "", - ) - .into(), - FailureLocation::InRegion { - region: (0, "").into(), - offset: n - 1, - }, - )], - ); - } - } -} diff --git a/src/protocol/halo2/test/kzg.rs b/src/protocol/halo2/test/kzg.rs deleted file mode 100644 index 771e1a70..00000000 --- a/src/protocol/halo2/test/kzg.rs +++ /dev/null @@ -1,232 +0,0 @@ -use crate::{ - protocol::halo2::test::{MainGateWithPlookup, MainGateWithRange}, - util::fe_to_limbs, -}; -use halo2_curves::{pairing::Engine, CurveAffine}; -use halo2_proofs::poly::{ - commitment::{CommitmentScheme, Params, ParamsProver}, - kzg::commitment::{KZGCommitmentScheme, ParamsKZG}, -}; -use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; -use std::{fmt::Debug, fs}; - -mod halo2; -mod native; - -#[cfg(feature = "evm")] -mod evm; - -pub const LIMBS: usize = 4; -pub const BITS: usize = 68; - -pub fn read_or_create_srs(k: u32) -> ParamsKZG { - const DIR: &str = "./src/protocol/halo2/test/kzg/fixture"; - let path = format!("{}/k-{}.srs", DIR, k); - match fs::File::open(path.as_str()) { - Ok(mut file) => ParamsKZG::::read(&mut file).unwrap(), - Err(_) => { - fs::create_dir_all(DIR).unwrap(); - let params = - KZGCommitmentScheme::::new_params(k, ChaCha20Rng::from_seed(Default::default())); - let mut file = fs::File::create(path.as_str()).unwrap(); - params.write(&mut file).unwrap(); - params - } - } -} - -pub fn main_gate_with_range_with_mock_kzg_accumulator( -) -> MainGateWithRange { - let g = read_or_create_srs::(3).get_g(); - let [g1, s_g1] = [g[0], g[1]].map(|point| point.coordinates().unwrap()); - MainGateWithRange::new( - [*s_g1.x(), *s_g1.y(), *g1.x(), *g1.y()] - .iter() - .cloned() - .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) - .collect(), - ) -} - -pub fn main_gate_with_plookup_with_mock_kzg_accumulator( - k: u32, -) -> MainGateWithPlookup { - let g = read_or_create_srs::(3).get_g(); - let [g1, s_g1] = [g[0], g[1]].map(|point| point.coordinates().unwrap()); - MainGateWithPlookup::new( - k, - [*s_g1.x(), *s_g1.y(), *g1.x(), *g1.y()] - .iter() - .cloned() - .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) - .collect(), - ) -} - -#[macro_export] -macro_rules! halo2_kzg_config { - ($zk:expr, $num_proof:expr) => { - $crate::protocol::halo2::Config { - zk: $zk, - query_instance: false, - num_instance: Vec::new(), - num_proof: $num_proof, - accumulator_indices: None, - } - }; - ($zk:expr, $num_proof:expr, $accumulator_indices:expr) => { - $crate::protocol::halo2::Config { - zk: $zk, - query_instance: false, - num_instance: Vec::new(), - num_proof: $num_proof, - accumulator_indices: Some($accumulator_indices), - } - }; -} - -#[macro_export] -macro_rules! halo2_kzg_prepare { - ($k:expr, $config:expr, $create_circuit:expr) => {{ - use $crate::{ - protocol::halo2::{compile, test::kzg::read_or_create_srs}, - util::{GroupEncoding, Itertools}, - }; - use halo2_curves::bn256::{Bn256, G1}; - use halo2_proofs::{ - plonk::{keygen_pk, keygen_vk}, - poly::kzg::commitment::KZGCommitmentScheme, - }; - use std::{iter}; - - let circuits = iter::repeat_with(|| $create_circuit) - .take($config.num_proof) - .collect_vec(); - - let params = read_or_create_srs::($k); - let pk = if $config.zk { - let vk = keygen_vk::, _, true>(¶ms, &circuits[0]).unwrap(); - let pk = keygen_pk::, _, true>(¶ms, vk, &circuits[0]).unwrap(); - pk - } else { - let vk = keygen_vk::, _, false>(¶ms, &circuits[0]).unwrap(); - let pk = keygen_pk::, _, false>(¶ms, vk, &circuits[0]).unwrap(); - pk - }; - - let mut config = $config; - config.num_instance = circuits[0].instances().iter().map(|instances| instances.len()).collect(); - let protocol = compile::(pk.get_vk(), config); - assert_eq!( - protocol.preprocessed.len(), - protocol.preprocessed - .iter() - .map(|ec_point| <[u8; 32]>::try_from(ec_point.to_bytes().as_ref().to_vec()).unwrap()) - .unique() - .count() - ); - - (params, pk, protocol, circuits) - }}; -} - -#[macro_export] -macro_rules! halo2_kzg_create_snark { - ($params:expr, $pk:expr, $protocol:expr, $circuits:expr, $prover:ty, $verifier:ty, $verification_strategy:ty, $transcript_read:ty, $transcript_write:ty, $encoded_challenge:ty) => {{ - use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme; - use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; - use $crate::{ - collect_slice, - protocol::{halo2::test::create_proof_checked, Snark}, - util::Itertools, - }; - - let instances = $circuits - .iter() - .map(|circuit| circuit.instances()) - .collect_vec(); - let proof = { - collect_slice!(instances, 2); - #[allow(clippy::needless_borrow)] - if $protocol.zk { - create_proof_checked::< - KZGCommitmentScheme<_>, - _, - $prover, - $verifier, - $verification_strategy, - $transcript_read, - $transcript_write, - $encoded_challenge, - _, - true, - >( - $params, - $pk, - $circuits, - &instances, - &mut ChaCha20Rng::from_seed(Default::default()), - ) - } else { - create_proof_checked::< - KZGCommitmentScheme<_>, - _, - $prover, - $verifier, - $verification_strategy, - $transcript_read, - $transcript_write, - $encoded_challenge, - _, - false, - >( - $params, - $pk, - $circuits, - &instances, - &mut ChaCha20Rng::from_seed(Default::default()), - ) - } - }; - - Snark::new( - $protocol.clone(), - instances.into_iter().flatten().collect_vec(), - proof, - ) - }}; -} - -#[macro_export] -macro_rules! halo2_kzg_native_accumulate { - ($protocol:expr, $statements:expr, $scheme:ty, $transcript:expr, $stretagy:expr) => {{ - use $crate::{loader::native::NativeLoader, scheme::kzg::AccumulationScheme}; - - <$scheme>::accumulate( - $protocol, - &NativeLoader, - $statements, - $transcript, - $stretagy, - ) - .unwrap(); - }}; -} - -#[macro_export] -macro_rules! halo2_kzg_native_verify { - ($params:ident, $protocol:expr, $statements:expr, $scheme:ty, $transcript:expr) => {{ - use halo2_curves::bn256::Bn256; - use halo2_proofs::poly::commitment::ParamsProver; - use $crate::{ - halo2_kzg_native_accumulate, - protocol::halo2::test::kzg::{BITS, LIMBS}, - scheme::kzg::SameCurveAccumulation, - }; - - let mut stretagy = SameCurveAccumulation::<_, _, LIMBS, BITS>::default(); - halo2_kzg_native_accumulate!($protocol, $statements, $scheme, $transcript, &mut stretagy); - - assert!(stretagy.decide::($params.get_g()[0], $params.g2(), $params.s_g2())); - }}; -} diff --git a/src/protocol/halo2/test/kzg/evm.rs b/src/protocol/halo2/test/kzg/evm.rs deleted file mode 100644 index 6caf0e68..00000000 --- a/src/protocol/halo2/test/kzg/evm.rs +++ /dev/null @@ -1,168 +0,0 @@ -use crate::{ - halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_evm_verify, halo2_kzg_native_verify, - halo2_kzg_prepare, - loader::evm::EvmTranscript, - protocol::halo2::{ - test::{ - kzg::{ - halo2::Accumulation, main_gate_with_plookup_with_mock_kzg_accumulator, - main_gate_with_range_with_mock_kzg_accumulator, LIMBS, - }, - StandardPlonk, - }, - util::evm::ChallengeEvm, - }, - scheme::kzg::PlonkAccumulationScheme, -}; -use halo2_proofs::poly::kzg::{ - multiopen::{ProverGWC, VerifierGWC}, - strategy::AccumulatorStrategy, -}; -use paste::paste; -use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; - -#[macro_export] -macro_rules! halo2_kzg_evm_verify { - ($params:expr, $protocol:expr, $statements:expr, $proof:expr, $scheme:ty) => {{ - use halo2_curves::bn256::{Fq, Fr}; - use halo2_proofs::poly::commitment::ParamsProver; - use std::{iter, rc::Rc}; - use $crate::{ - loader::evm::{encode_calldata, execute, EvmLoader, EvmTranscript}, - protocol::halo2::test::kzg::{BITS, LIMBS}, - scheme::kzg::{AccumulationScheme, SameCurveAccumulation}, - util::{Itertools, TranscriptRead}, - }; - - let loader = EvmLoader::new::(); - let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); - let statements = $statements - .iter() - .map(|instance| { - iter::repeat_with(|| transcript.read_scalar().unwrap()) - .take(instance.len()) - .collect_vec() - }) - .collect_vec(); - let mut strategy = SameCurveAccumulation::<_, _, LIMBS, BITS>::default(); - <$scheme>::accumulate( - $protocol, - &loader, - statements, - &mut transcript, - &mut strategy, - ) - .unwrap(); - let code = strategy.code($params.get_g()[0], $params.g2(), $params.s_g2()); - let (accept, total_cost, costs) = execute(code, encode_calldata($statements, $proof)); - loader.print_gas_metering(costs); - println!("Total: {}", total_cost); - assert!(accept); - }}; -} - -macro_rules! test { - (@ #[$($attr:meta),*], $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - paste! { - $(#[$attr])* - fn []() { - let (params, pk, protocol, circuits) = halo2_kzg_prepare!( - $k, - $config, - $create_circuit - ); - let snark = halo2_kzg_create_snark!( - ¶ms, - &pk, - &protocol, - &circuits, - ProverGWC<_>, - VerifierGWC<_>, - AccumulatorStrategy<_>, - EvmTranscript<_, _, _, _>, - EvmTranscript<_, _, _, _>, - ChallengeEvm<_> - ); - halo2_kzg_native_verify!( - params, - &snark.protocol, - snark.statements.clone(), - PlonkAccumulationScheme, - &mut EvmTranscript::<_, NativeLoader, _, _>::new(snark.proof.as_slice()) - ); - halo2_kzg_evm_verify!( - params, - &snark.protocol, - snark.statements, - snark.proof, - PlonkAccumulationScheme - ); - } - } - }; - ($name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - test!(@ #[test], $name, $k, $config, $create_circuit); - }; - (#[ignore = $reason:literal], $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - test!(@ #[test, ignore = $reason], $name, $k, $config, $create_circuit); - }; -} - -test!( - zk_standard_plonk_rand, - 9, - halo2_kzg_config!(true, 1), - StandardPlonk::<_>::rand(ChaCha20Rng::from_seed(Default::default())) -); -test!( - zk_main_gate_with_range_with_mock_kzg_accumulator, - 9, - halo2_kzg_config!(true, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - main_gate_with_range_with_mock_kzg_accumulator::() -); -test!( - #[ignore = "cause it requires 16GB memory to run"], - zk_accumulation_two_snark, - 21, - halo2_kzg_config!(true, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - Accumulation::two_snark(true) -); -test!( - #[ignore = "cause it requires 32GB memory to run"], - zk_accumulation_two_snark_with_accumulator, - 22, - halo2_kzg_config!(true, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - Accumulation::two_snark_with_accumulator(true) -); -test!( - standard_plonk_rand, - 9, - halo2_kzg_config!(false, 1), - StandardPlonk::<_>::rand(ChaCha20Rng::from_seed(Default::default())) -); -test!( - main_gate_with_range_with_mock_kzg_accumulator, - 9, - halo2_kzg_config!(false, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - main_gate_with_range_with_mock_kzg_accumulator::() -); -test!( - main_gate_with_plookup_with_mock_kzg_accumulator, - 9, - halo2_kzg_config!(false, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - main_gate_with_plookup_with_mock_kzg_accumulator::(9) -); -test!( - #[ignore = "cause it requires 16GB memory to run"], - accumulation_two_snark, - 21, - halo2_kzg_config!(false, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - Accumulation::two_snark(false) -); -test!( - #[ignore = "cause it requires 32GB memory to run"], - accumulation_two_snark_with_accumulator, - 22, - halo2_kzg_config!(false, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - Accumulation::two_snark_with_accumulator(false) -); diff --git a/src/protocol/halo2/test/kzg/halo2.rs b/src/protocol/halo2/test/kzg/halo2.rs deleted file mode 100644 index cd8e884b..00000000 --- a/src/protocol/halo2/test/kzg/halo2.rs +++ /dev/null @@ -1,380 +0,0 @@ -use crate::{ - collect_slice, halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_accumulate, - halo2_kzg_native_verify, halo2_kzg_prepare, - loader::{halo2, native::NativeLoader}, - protocol::{ - halo2::{ - test::{ - kzg::{BITS, LIMBS}, - MainGateWithRange, MainGateWithRangeConfig, StandardPlonk, - }, - util::halo2::ChallengeScalar, - }, - Protocol, Snark, - }, - scheme::kzg::{self, AccumulationScheme, ShplonkAccumulationScheme}, - util::{fe_to_limbs, Curve, Group, Itertools, PrimeCurveAffine}, -}; -use halo2_curves::bn256::{Fr, G1Affine, G1}; -use halo2_proofs::{ - circuit::{floor_planner::V1, Layouter, Value}, - plonk, - plonk::Circuit, - poly::{ - commitment::ParamsProver, - kzg::{ - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::AccumulatorStrategy, - }, - }, - transcript::{Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer}, -}; -use halo2_wrong_ecc::{self, maingate::RegionCtx}; -use halo2_wrong_transcript::NativeRepresentation; -use paste::paste; -use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; -use std::rc::Rc; - -const T: usize = 5; -const RATE: usize = 4; -const R_F: usize = 8; -const R_P: usize = 57; - -type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip; -type Halo2Loader<'a, 'b, C> = halo2::Halo2Loader<'a, 'b, C, LIMBS, BITS>; -type PoseidonTranscript = - halo2::PoseidonTranscript; -type SameCurveAccumulation = kzg::SameCurveAccumulation; - -pub struct SnarkWitness { - protocol: Protocol, - statements: Vec::Scalar>>>, - proof: Value>, -} - -impl From> for SnarkWitness { - fn from(snark: Snark) -> Self { - Self { - protocol: snark.protocol, - statements: snark - .statements - .into_iter() - .map(|statements| statements.into_iter().map(Value::known).collect_vec()) - .collect(), - proof: Value::known(snark.proof), - } - } -} - -impl SnarkWitness { - pub fn without_witnesses(&self) -> Self { - SnarkWitness { - protocol: self.protocol.clone(), - statements: self - .statements - .iter() - .map(|statements| vec![Value::unknown(); statements.len()]) - .collect(), - proof: Value::unknown(), - } - } -} - -pub fn accumulate<'a, 'b>( - loader: &Rc>, - stretagy: &mut SameCurveAccumulation>>, - snark: &SnarkWitness, -) -> Result<(), plonk::Error> { - let mut transcript = PoseidonTranscript::<_, Rc>, _, _>::new( - loader, - snark.proof.as_ref().map(|proof| proof.as_slice()), - ); - let statements = snark - .statements - .iter() - .map(|statements| { - statements - .iter() - .map(|statement| loader.assign_scalar(*statement)) - .collect_vec() - }) - .collect_vec(); - ShplonkAccumulationScheme::accumulate( - &snark.protocol, - loader, - statements, - &mut transcript, - stretagy, - ) - .map_err(|_| plonk::Error::Synthesis)?; - Ok(()) -} - -pub struct Accumulation { - g1: G1Affine, - snarks: Vec>, - instances: Vec, -} - -impl Accumulation { - pub fn accumulator_indices() -> Vec<(usize, usize)> { - (0..4 * LIMBS).map(|idx| (0, idx)).collect() - } - - pub fn two_snark(zk: bool) -> Self { - const K: u32 = 9; - - let (params, snark1) = { - let (params, pk, protocol, circuits) = halo2_kzg_prepare!( - K, - halo2_kzg_config!(zk, 1), - StandardPlonk::<_>::rand(ChaCha20Rng::from_seed(Default::default())) - ); - let snark = halo2_kzg_create_snark!( - ¶ms, - &pk, - &protocol, - &circuits, - ProverSHPLONK<_>, - VerifierSHPLONK<_>, - AccumulatorStrategy<_>, - PoseidonTranscript<_, _, _, _>, - PoseidonTranscript<_, _, _, _>, - ChallengeScalar<_> - ); - (params, snark) - }; - let snark2 = { - let (params, pk, protocol, circuits) = halo2_kzg_prepare!( - K, - halo2_kzg_config!(zk, 1), - MainGateWithRange::<_>::rand(ChaCha20Rng::from_seed(Default::default())) - ); - halo2_kzg_create_snark!( - ¶ms, - &pk, - &protocol, - &circuits, - ProverSHPLONK<_>, - VerifierSHPLONK<_>, - AccumulatorStrategy<_>, - PoseidonTranscript<_, _, _, _>, - PoseidonTranscript<_, _, _, _>, - ChallengeScalar<_> - ) - }; - - let mut strategy = SameCurveAccumulation::::default(); - halo2_kzg_native_accumulate!( - &snark1.protocol, - snark1.statements.clone(), - ShplonkAccumulationScheme, - &mut PoseidonTranscript::::init(snark1.proof.as_slice()), - &mut strategy - ); - halo2_kzg_native_accumulate!( - &snark2.protocol, - snark2.statements.clone(), - ShplonkAccumulationScheme, - &mut PoseidonTranscript::::init(snark2.proof.as_slice()), - &mut strategy - ); - - let g1 = params.get_g()[0]; - let accumulator = strategy.finalize(g1.to_curve()); - let instances = [ - accumulator.0.to_affine().x, - accumulator.0.to_affine().y, - accumulator.1.to_affine().x, - accumulator.1.to_affine().y, - ] - .map(fe_to_limbs::<_, _, LIMBS, BITS>) - .concat(); - - Self { - g1, - snarks: vec![snark1.into(), snark2.into()], - instances, - } - } - - pub fn two_snark_with_accumulator(zk: bool) -> Self { - const K: u32 = 21; - - let (params, pk, protocol, circuits) = halo2_kzg_prepare!( - K, - halo2_kzg_config!(zk, 2, Self::accumulator_indices()), - Self::two_snark(zk) - ); - let snark = halo2_kzg_create_snark!( - ¶ms, - &pk, - &protocol, - &circuits, - ProverSHPLONK<_>, - VerifierSHPLONK<_>, - AccumulatorStrategy<_>, - PoseidonTranscript<_, _, _, _>, - PoseidonTranscript<_, _, _, _>, - ChallengeScalar<_> - ); - - let mut strategy = SameCurveAccumulation::::default(); - halo2_kzg_native_accumulate!( - &snark.protocol, - snark.statements.clone(), - ShplonkAccumulationScheme, - &mut PoseidonTranscript::::init(snark.proof.as_slice()), - &mut strategy - ); - - let g1 = params.get_g()[0]; - let accumulator = strategy.finalize(g1.to_curve()); - let instances = [ - accumulator.0.to_affine().x, - accumulator.0.to_affine().y, - accumulator.1.to_affine().x, - accumulator.1.to_affine().y, - ] - .map(fe_to_limbs::<_, _, LIMBS, BITS>) - .concat(); - - Self { - g1, - snarks: vec![snark.into()], - instances, - } - } - - pub fn instances(&self) -> Vec> { - vec![self.instances.clone()] - } -} - -impl Circuit for Accumulation { - type Config = MainGateWithRangeConfig; - type FloorPlanner = V1; - - fn without_witnesses(&self) -> Self { - Self { - g1: self.g1, - snarks: self - .snarks - .iter() - .map(SnarkWitness::without_witnesses) - .collect(), - instances: Vec::new(), - } - } - - fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { - MainGateWithRangeConfig::configure::( - meta, - vec![BITS / LIMBS], - BaseFieldEccChip::::rns().overflow_lengths(), - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), plonk::Error> { - config.load_table(&mut layouter)?; - - let (lhs, rhs) = layouter.assign_region( - || "", - |mut region| { - let mut offset = 0; - let ctx = RegionCtx::new(&mut region, &mut offset); - - let loader = Halo2Loader::::new(config.ecc_config(), ctx); - let mut stretagy = SameCurveAccumulation::default(); - for snark in self.snarks.iter() { - accumulate(&loader, &mut stretagy, snark)?; - } - let (lhs, rhs) = stretagy.finalize(self.g1); - - loader.print_row_metering(); - println!("Total: {}", offset); - - Ok((lhs, rhs)) - }, - )?; - - let ecc_chip = BaseFieldEccChip::::new(config.ecc_config()); - ecc_chip.expose_public(layouter.namespace(|| ""), lhs, 0)?; - ecc_chip.expose_public(layouter.namespace(|| ""), rhs, 2 * LIMBS)?; - - Ok(()) - } -} - -macro_rules! test { - (@ #[$($attr:meta),*], $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - paste! { - $(#[$attr])* - fn []() { - let (params, pk, protocol, circuits) = halo2_kzg_prepare!( - $k, - $config, - $create_circuit - ); - let snark = halo2_kzg_create_snark!( - ¶ms, - &pk, - &protocol, - &circuits, - ProverSHPLONK<_>, - VerifierSHPLONK<_>, - AccumulatorStrategy<_>, - Blake2bWrite<_, _, _>, - Blake2bRead<_, _, _>, - Challenge255<_> - ); - halo2_kzg_native_verify!( - params, - &snark.protocol, - snark.statements, - ShplonkAccumulationScheme, - &mut Blake2bRead::<_, G1Affine, _>::init(snark.proof.as_slice()) - ); - } - } - }; - ($name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - test!(@ #[test], $name, $k, $config, $create_circuit); - }; - (#[ignore = $reason:literal], $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { - test!(@ #[test, ignore = $reason], $name, $k, $config, $create_circuit); - }; -} - -test!( - #[ignore = "cause it requires 16GB memory to run"], - zk_accumulation_two_snark, - 21, - halo2_kzg_config!(true, 1, Accumulation::accumulator_indices()), - Accumulation::two_snark(true) -); -test!( - #[ignore = "cause it requires 32GB memory to run"], - zk_accumulation_two_snark_with_accumulator, - 22, - halo2_kzg_config!(true, 1, Accumulation::accumulator_indices()), - Accumulation::two_snark_with_accumulator(true) -); -test!( - #[ignore = "cause it requires 16GB memory to run"], - accumulation_two_snark, - 21, - halo2_kzg_config!(false, 1, Accumulation::accumulator_indices()), - Accumulation::two_snark(false) -); -test!( - #[ignore = "cause it requires 32GB memory to run"], - accumulation_two_snark_with_accumulator, - 22, - halo2_kzg_config!(false, 1, Accumulation::accumulator_indices()), - Accumulation::two_snark_with_accumulator(false) -); diff --git a/src/protocol/halo2/util.rs b/src/protocol/halo2/util.rs deleted file mode 100644 index 437d59cb..00000000 --- a/src/protocol/halo2/util.rs +++ /dev/null @@ -1,81 +0,0 @@ -use crate::{ - loader::native::NativeLoader, - util::{ - Curve, Itertools, PrimeCurveAffine, PrimeField, Transcript, TranscriptRead, - UncompressedEncoding, - }, - Error, -}; -use halo2_proofs::{ - arithmetic::{CurveAffine, CurveExt}, - transcript::{Blake2bRead, Challenge255}, -}; -use std::{io::Read, iter}; - -pub mod halo2; - -#[cfg(feature = "evm")] -pub mod evm; - -impl UncompressedEncoding for C -where - ::Base: PrimeField, -{ - type Uncompressed = [u8; 64]; - - fn to_uncompressed(&self) -> [u8; 64] { - let coordinates = self.to_affine().coordinates().unwrap(); - iter::empty() - .chain(coordinates.x().to_repr().as_ref()) - .chain(coordinates.y().to_repr().as_ref()) - .cloned() - .collect_vec() - .try_into() - .unwrap() - } - - fn from_uncompressed(uncompressed: [u8; 64]) -> Option { - let x = Option::from(::Base::from_repr( - uncompressed[..32].to_vec().try_into().unwrap(), - ))?; - let y = Option::from(::Base::from_repr( - uncompressed[32..].to_vec().try_into().unwrap(), - ))?; - C::AffineExt::from_xy(x, y) - .map(|ec_point| ec_point.to_curve()) - .into() - } -} - -impl Transcript - for Blake2bRead> -{ - fn squeeze_challenge(&mut self) -> C::Scalar { - *halo2_proofs::transcript::Transcript::squeeze_challenge_scalar::(self) - } - - fn common_ec_point(&mut self, ec_point: &C::CurveExt) -> Result<(), Error> { - halo2_proofs::transcript::Transcript::common_point(self, ec_point.to_affine()) - .map_err(|err| Error::Transcript(err.kind(), err.to_string())) - } - - fn common_scalar(&mut self, scalar: &C::Scalar) -> Result<(), Error> { - halo2_proofs::transcript::Transcript::common_scalar(self, *scalar) - .map_err(|err| Error::Transcript(err.kind(), err.to_string())) - } -} - -impl TranscriptRead - for Blake2bRead> -{ - fn read_scalar(&mut self) -> Result { - halo2_proofs::transcript::TranscriptRead::read_scalar(self) - .map_err(|err| Error::Transcript(err.kind(), err.to_string())) - } - - fn read_ec_point(&mut self) -> Result { - halo2_proofs::transcript::TranscriptRead::read_point(self) - .map(|ec_point| ec_point.to_curve()) - .map_err(|err| Error::Transcript(err.kind(), err.to_string())) - } -} diff --git a/src/protocol/halo2/util/evm.rs b/src/protocol/halo2/util/evm.rs deleted file mode 100644 index c8da8fe9..00000000 --- a/src/protocol/halo2/util/evm.rs +++ /dev/null @@ -1,142 +0,0 @@ -use crate::{ - loader::{ - evm::{u256_to_field, EvmTranscript}, - native::NativeLoader, - }, - util::{self, Curve, PrimeField, UncompressedEncoding}, - Error, -}; -use ethereum_types::U256; -use halo2_curves::{Coordinates, CurveAffine}; -use halo2_proofs::transcript::{ - EncodedChallenge, Transcript, TranscriptRead, TranscriptReadBuffer, TranscriptWrite, - TranscriptWriterBuffer, -}; -use std::io::{self, Read, Write}; - -pub struct ChallengeEvm(C::Scalar) -where - C::CurveExt: Curve + UncompressedEncoding, - C::Scalar: PrimeField; - -impl EncodedChallenge for ChallengeEvm -where - C::CurveExt: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - type Input = [u8; 32]; - - fn new(challenge_input: &[u8; 32]) -> Self { - ChallengeEvm(u256_to_field(U256::from_big_endian(challenge_input))) - } - - fn get_scalar(&self) -> C::Scalar { - self.0 - } -} - -impl Transcript> - for EvmTranscript> -where - C::CurveExt: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - fn squeeze_challenge(&mut self) -> ChallengeEvm { - ChallengeEvm(util::Transcript::squeeze_challenge(self)) - } - - fn common_point(&mut self, ec_point: C) -> io::Result<()> { - match util::Transcript::common_ec_point(self, &ec_point.to_curve()) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - _ => Ok(()), - } - } - - fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { - match util::Transcript::common_scalar(self, &scalar) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - _ => Ok(()), - } - } -} - -impl TranscriptRead> - for EvmTranscript> -where - C::CurveExt: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - fn read_point(&mut self) -> io::Result { - match util::TranscriptRead::read_ec_point(self) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - Ok(value) => Ok(value.to_affine()), - } - } - - fn read_scalar(&mut self) -> io::Result { - match util::TranscriptRead::read_scalar(self) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - Ok(value) => Ok(value), - } - } -} - -impl TranscriptReadBuffer> - for EvmTranscript> -where - C::CurveExt: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - fn init(reader: R) -> Self { - Self::new(reader) - } -} - -impl TranscriptWrite> - for EvmTranscript> -where - C::CurveExt: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - fn write_point(&mut self, ec_point: C) -> io::Result<()> { - Transcript::>::common_point(self, ec_point)?; - let coords: Coordinates = Option::from(ec_point.coordinates()).ok_or_else(|| { - io::Error::new( - io::ErrorKind::Other, - "Cannot write points at infinity to the transcript", - ) - })?; - let mut x = coords.x().to_repr(); - let mut y = coords.y().to_repr(); - x.as_mut().reverse(); - y.as_mut().reverse(); - self.stream_mut().write_all(x.as_ref())?; - self.stream_mut().write_all(y.as_ref()) - } - - fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { - Transcript::>::common_scalar(self, scalar)?; - let mut data = scalar.to_repr(); - data.as_mut().reverse(); - self.stream_mut().write_all(data.as_ref()) - } -} - -impl TranscriptWriterBuffer> - for EvmTranscript> -where - C::CurveExt: Curve + UncompressedEncoding, - C::Scalar: PrimeField, -{ - fn init(writer: W) -> Self { - Self::new(writer) - } - - fn finalize(self) -> W { - self.finalize() - } -} diff --git a/src/protocol/halo2/util/halo2.rs b/src/protocol/halo2/util/halo2.rs deleted file mode 100644 index dfda15a8..00000000 --- a/src/protocol/halo2/util/halo2.rs +++ /dev/null @@ -1,212 +0,0 @@ -use crate::{ - loader::{halo2::PoseidonTranscript, native::NativeLoader}, - util::{self, Curve, PrimeField}, - Error, -}; -use halo2_curves::CurveAffine; -use halo2_proofs::transcript::{ - EncodedChallenge, Transcript, TranscriptRead, TranscriptReadBuffer, TranscriptWrite, - TranscriptWriterBuffer, -}; -use halo2_wrong_transcript::NativeRepresentation; -use poseidon::Poseidon; -use std::io::{self, Read, Write}; - -pub struct ChallengeScalar(C::Scalar); - -impl EncodedChallenge for ChallengeScalar { - type Input = C::Scalar; - - fn new(challenge_input: &C::Scalar) -> Self { - ChallengeScalar(*challenge_input) - } - - fn get_scalar(&self) -> C::Scalar { - self.0 - } -} - -impl< - C: CurveAffine, - S, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > Transcript> - for PoseidonTranscript< - C, - NativeLoader, - S, - Poseidon, - NativeRepresentation, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn squeeze_challenge(&mut self) -> ChallengeScalar { - ChallengeScalar::new(&util::Transcript::squeeze_challenge(self)) - } - - fn common_point(&mut self, ec_point: C) -> io::Result<()> { - match util::Transcript::common_ec_point(self, &ec_point.to_curve()) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - _ => Ok(()), - } - } - - fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { - match util::Transcript::common_scalar(self, &scalar) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - _ => Ok(()), - } - } -} - -impl< - C: CurveAffine, - R: Read, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptRead> - for PoseidonTranscript< - C, - NativeLoader, - R, - Poseidon, - NativeRepresentation, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn read_point(&mut self) -> io::Result { - match util::TranscriptRead::read_ec_point(self) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - Ok(value) => Ok(value.to_affine()), - } - } - - fn read_scalar(&mut self) -> io::Result { - match util::TranscriptRead::read_scalar(self) { - Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), - Err(_) => unreachable!(), - Ok(value) => Ok(value), - } - } -} - -impl< - C: CurveAffine, - R: Read, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptReadBuffer> - for PoseidonTranscript< - C, - NativeLoader, - R, - Poseidon, - NativeRepresentation, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn init(reader: R) -> Self { - Self::new(reader) - } -} - -impl< - C: CurveAffine, - W: Write, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptWrite> - for PoseidonTranscript< - C, - NativeLoader, - W, - Poseidon, - NativeRepresentation, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn write_point(&mut self, ec_point: C) -> io::Result<()> { - Transcript::>::common_point(self, ec_point)?; - let data = ec_point.to_bytes(); - self.stream_mut().write_all(data.as_ref()) - } - - fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { - Transcript::>::common_scalar(self, scalar)?; - let data = scalar.to_repr(); - self.stream_mut().write_all(data.as_ref()) - } -} - -impl< - C: CurveAffine, - W: Write, - const LIMBS: usize, - const BITS: usize, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptWriterBuffer> - for PoseidonTranscript< - C, - NativeLoader, - W, - Poseidon, - NativeRepresentation, - LIMBS, - BITS, - T, - RATE, - R_F, - R_P, - > -{ - fn init(writer: W) -> Self { - Self::new(writer) - } - - fn finalize(self) -> W { - self.finalize() - } -} diff --git a/src/scheme.rs b/src/scheme.rs deleted file mode 100644 index d883d742..00000000 --- a/src/scheme.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod kzg; diff --git a/src/scheme/kzg.rs b/src/scheme/kzg.rs deleted file mode 100644 index e1f949ab..00000000 --- a/src/scheme/kzg.rs +++ /dev/null @@ -1,35 +0,0 @@ -use crate::{ - protocol::Protocol, - util::{Curve, Expression}, -}; - -mod accumulation; -mod cost; -mod msm; - -pub use accumulation::{ - plonk::PlonkAccumulationScheme, shplonk::ShplonkAccumulationScheme, AccumulationScheme, - AccumulationStrategy, Accumulator, SameCurveAccumulation, -}; -pub use cost::{Cost, CostEstimation}; -pub use msm::MSM; - -pub fn langranges( - protocol: &Protocol, - statements: &[Vec], -) -> impl IntoIterator { - protocol - .relations - .iter() - .cloned() - .sum::>() - .used_langrange() - .into_iter() - .chain( - 0..statements - .iter() - .map(|statement| statement.len()) - .max() - .unwrap_or_default() as i32, - ) -} diff --git a/src/scheme/kzg/accumulation.rs b/src/scheme/kzg/accumulation.rs deleted file mode 100644 index 6931639d..00000000 --- a/src/scheme/kzg/accumulation.rs +++ /dev/null @@ -1,171 +0,0 @@ -use crate::{ - loader::Loader, - protocol::Protocol, - scheme::kzg::msm::MSM, - util::{Curve, Transcript}, - Error, -}; -use std::ops::{Add, AddAssign, Mul, MulAssign}; - -pub mod plonk; -pub mod shplonk; - -pub trait AccumulationScheme -where - C: Curve, - L: Loader, - T: Transcript, - S: AccumulationStrategy, -{ - type Proof; - - fn accumulate( - protocol: &Protocol, - loader: &L, - statements: Vec>, - transcript: &mut T, - strategy: &mut S, - ) -> Result; -} - -pub trait AccumulationStrategy -where - C: Curve, - L: Loader, - T: Transcript, -{ - type Output; - - fn extract_accumulator( - &self, - _: &Protocol, - _: &L, - _: &mut T, - _: &[Vec], - ) -> Option> { - None - } - - fn process( - &mut self, - loader: &L, - transcript: &mut T, - proof: P, - accumulator: Accumulator, - ) -> Result; -} - -#[derive(Clone, Debug)] -pub struct Accumulator -where - C: Curve, - L: Loader, -{ - lhs: MSM, - rhs: MSM, -} - -impl Accumulator -where - C: Curve, - L: Loader, -{ - pub fn new(lhs: MSM, rhs: MSM) -> Self { - Self { lhs, rhs } - } - - pub fn scale(&mut self, scalar: &L::LoadedScalar) { - self.lhs *= scalar; - self.rhs *= scalar; - } - - pub fn extend(&mut self, other: Self) { - self.lhs += other.lhs; - self.rhs += other.rhs; - } - - pub fn evaluate(self, g1: C) -> (L::LoadedEcPoint, L::LoadedEcPoint) { - (self.lhs.evaluate(g1), self.rhs.evaluate(g1)) - } - - pub fn random_linear_combine( - scaled_accumulators: impl IntoIterator, - ) -> Self { - scaled_accumulators - .into_iter() - .map(|(scalar, accumulator)| accumulator * &scalar) - .reduce(|acc, scaled_accumulator| acc + scaled_accumulator) - .unwrap_or_default() - } -} - -impl Default for Accumulator -where - C: Curve, - L: Loader, -{ - fn default() -> Self { - Self { - lhs: MSM::default(), - rhs: MSM::default(), - } - } -} - -impl Add for Accumulator -where - C: Curve, - L: Loader, -{ - type Output = Self; - - fn add(mut self, rhs: Self) -> Self::Output { - self.extend(rhs); - self - } -} - -impl AddAssign for Accumulator -where - C: Curve, - L: Loader, -{ - fn add_assign(&mut self, rhs: Self) { - self.extend(rhs); - } -} - -impl Mul<&L::LoadedScalar> for Accumulator -where - C: Curve, - L: Loader, -{ - type Output = Self; - - fn mul(mut self, rhs: &L::LoadedScalar) -> Self::Output { - self.scale(rhs); - self - } -} - -impl MulAssign<&L::LoadedScalar> for Accumulator -where - C: Curve, - L: Loader, -{ - fn mul_assign(&mut self, rhs: &L::LoadedScalar) { - self.scale(rhs); - } -} - -pub struct SameCurveAccumulation, const LIMBS: usize, const BITS: usize> { - pub accumulator: Option>, -} - -impl, const LIMBS: usize, const BITS: usize> Default - for SameCurveAccumulation -{ - fn default() -> Self { - Self { accumulator: None } - } -} diff --git a/src/scheme/kzg/accumulation/plonk.rs b/src/scheme/kzg/accumulation/plonk.rs deleted file mode 100644 index 10c91eed..00000000 --- a/src/scheme/kzg/accumulation/plonk.rs +++ /dev/null @@ -1,373 +0,0 @@ -use crate::{ - loader::{LoadedScalar, Loader}, - protocol::Protocol, - scheme::kzg::{ - accumulation::{AccumulationScheme, AccumulationStrategy, Accumulator}, - cost::{Cost, CostEstimation}, - langranges, - msm::MSM, - }, - util::{ - CommonPolynomial, CommonPolynomialEvaluation, Curve, Expression, Field, Itertools, Query, - Rotation, TranscriptRead, - }, - Error, -}; -use std::{collections::HashMap, iter}; - -#[derive(Default)] -pub struct PlonkAccumulationScheme; - -impl AccumulationScheme for PlonkAccumulationScheme -where - C: Curve, - L: Loader, - T: TranscriptRead, - S: AccumulationStrategy>, -{ - type Proof = PlonkProof; - - fn accumulate( - protocol: &Protocol, - loader: &L, - statements: Vec>, - transcript: &mut T, - strategy: &mut S, - ) -> Result { - transcript.common_scalar(&loader.load_const(&protocol.transcript_initial_state))?; - - let proof = PlonkProof::read(protocol, statements, transcript)?; - let old_accumulator = - strategy.extract_accumulator(protocol, loader, transcript, &proof.statements); - - let common_poly_eval = { - let mut common_poly_eval = CommonPolynomialEvaluation::new( - &protocol.domain, - loader, - langranges(protocol, &proof.statements), - &proof.z, - ); - - L::LoadedScalar::batch_invert(common_poly_eval.denoms()); - - common_poly_eval - }; - - let commitments = proof.commitments(protocol, loader, &common_poly_eval); - let evaluations = proof.evaluations(protocol, loader, &common_poly_eval)?; - - let sets = rotation_sets(protocol); - let powers_of_u = &proof.u.powers(sets.len()); - let f = { - let powers_of_v = proof - .v - .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); - sets.iter() - .map(|set| set.msm(&commitments, &evaluations, &powers_of_v)) - .zip(powers_of_u.iter()) - .map(|(msm, power_of_u)| msm * power_of_u) - .sum::>() - }; - let z_omegas = sets.iter().map(|set| { - loader.load_const( - &protocol - .domain - .rotate_scalar(C::Scalar::one(), set.rotation), - ) * &proof.z - }); - - let rhs = proof - .ws - .iter() - .zip(powers_of_u.iter()) - .map(|(w, power_of_u)| MSM::base(w.clone()) * power_of_u) - .collect_vec(); - let lhs = f + rhs - .iter() - .zip(z_omegas) - .map(|(uw, z_omega)| uw.clone() * &z_omega) - .sum(); - - let mut accumulator = Accumulator::new(lhs, rhs.into_iter().sum()); - if let Some(old_accumulator) = old_accumulator { - accumulator += old_accumulator; - } - strategy.process(loader, transcript, proof, accumulator) - } -} - -pub struct PlonkProof> { - statements: Vec>, - auxiliaries: Vec, - challenges: Vec, - alpha: L::LoadedScalar, - quotients: Vec, - z: L::LoadedScalar, - evaluations: Vec, - v: L::LoadedScalar, - ws: Vec, - u: L::LoadedScalar, -} - -impl> PlonkProof { - fn read>( - protocol: &Protocol, - statements: Vec>, - transcript: &mut T, - ) -> Result { - if protocol.num_statement - != statements - .iter() - .map(|statements| statements.len()) - .collect_vec() - { - return Err(Error::InvalidInstances); - } - for statements in statements.iter() { - for statement in statements.iter() { - transcript.common_scalar(statement)?; - } - } - - let (auxiliaries, challenges) = { - let (auxiliaries, challenges) = protocol - .num_auxiliary - .iter() - .zip(protocol.num_challenge.iter()) - .map(|(&n, &m)| { - Ok(( - transcript.read_n_ec_points(n)?, - transcript.squeeze_n_challenges(m), - )) - }) - .collect::, Error>>()? - .into_iter() - .unzip::<_, _, Vec<_>, Vec<_>>(); - - ( - auxiliaries.into_iter().flatten().collect_vec(), - challenges.into_iter().flatten().collect_vec(), - ) - }; - - let alpha = transcript.squeeze_challenge(); - let quotients = { - let max_degree = protocol - .relations - .iter() - .map(Expression::degree) - .max() - .unwrap(); - transcript.read_n_ec_points(max_degree - 1)? - }; - - let z = transcript.squeeze_challenge(); - let evaluations = transcript.read_n_scalars(protocol.evaluations.len())?; - - let v = transcript.squeeze_challenge(); - let ws = transcript.read_n_ec_points(rotation_sets(protocol).len())?; - let u = transcript.squeeze_challenge(); - - Ok(Self { - statements, - auxiliaries, - challenges, - alpha, - quotients, - z, - evaluations, - v, - ws, - u, - }) - } - - fn commitments( - &self, - protocol: &Protocol, - loader: &L, - common_poly_eval: &CommonPolynomialEvaluation, - ) -> HashMap> { - iter::empty() - .chain( - protocol - .preprocessed - .iter() - .map(|value| MSM::base(loader.ec_point_load_const(value))) - .enumerate(), - ) - .chain({ - let auxiliary_offset = protocol.preprocessed.len() + protocol.num_statement.len(); - self.auxiliaries - .iter() - .cloned() - .enumerate() - .map(move |(i, auxiliary)| (auxiliary_offset + i, MSM::base(auxiliary))) - }) - .chain(iter::once(( - protocol.vanishing_poly(), - common_poly_eval - .zn() - .powers(self.quotients.len()) - .into_iter() - .zip(self.quotients.iter().cloned().map(MSM::base)) - .map(|(coeff, piece)| piece * &coeff) - .sum(), - ))) - .collect() - } - - fn evaluations( - &self, - protocol: &Protocol, - loader: &L, - common_poly_eval: &CommonPolynomialEvaluation, - ) -> Result, Error> { - let statement_evaluations = self.statements.iter().map(|statements| { - L::LoadedScalar::sum( - &statements - .iter() - .enumerate() - .map(|(i, statement)| { - common_poly_eval.get(CommonPolynomial::Lagrange(i as i32)) * statement - }) - .collect_vec(), - ) - }); - let mut evaluations = HashMap::::from_iter( - iter::empty() - .chain( - statement_evaluations - .into_iter() - .enumerate() - .map(|(i, evaluation)| { - ( - Query { - poly: protocol.preprocessed.len() + i, - rotation: Rotation::cur(), - }, - evaluation, - ) - }), - ) - .chain( - protocol - .evaluations - .iter() - .cloned() - .zip(self.evaluations.iter().cloned()), - ), - ); - - let powers_of_alpha = self.alpha.powers(protocol.relations.len()); - let quotient_evaluation = L::LoadedScalar::sum( - &powers_of_alpha - .into_iter() - .rev() - .zip(protocol.relations.iter()) - .map(|(power_of_alpha, relation)| { - relation - .evaluate( - &|scalar| Ok(loader.load_const(&scalar)), - &|poly| Ok(common_poly_eval.get(poly)), - &|index| { - evaluations - .get(&index) - .cloned() - .ok_or(Error::MissingQuery(index)) - }, - &|index| { - self.challenges - .get(index) - .cloned() - .ok_or(Error::MissingChallenge(index)) - }, - &|a| a.map(|a| -a), - &|a, b| a.and_then(|a| Ok(a + b?)), - &|a, b| a.and_then(|a| Ok(a * b?)), - &|a, scalar| a.map(|a| a * loader.load_const(&scalar)), - ) - .map(|evaluation| power_of_alpha * evaluation) - }) - .collect::, Error>>()?, - ) * &common_poly_eval.zn_minus_one_inv(); - - evaluations.insert( - Query { - poly: protocol.vanishing_poly(), - rotation: Rotation::cur(), - }, - quotient_evaluation, - ); - - Ok(evaluations) - } -} - -struct RotationSet { - rotation: Rotation, - polys: Vec, -} - -impl RotationSet { - fn msm>( - &self, - commitments: &HashMap>, - evaluations: &HashMap, - powers_of_v: &[L::LoadedScalar], - ) -> MSM { - self.polys - .iter() - .map(|poly| { - let commitment = commitments.get(poly).unwrap().clone(); - let evalaution = evaluations - .get(&Query::new(*poly, self.rotation)) - .unwrap() - .clone(); - commitment - MSM::scalar(evalaution) - }) - .zip(powers_of_v.iter()) - .map(|(msm, power_of_v)| msm * power_of_v) - .sum() - } -} - -fn rotation_sets(protocol: &Protocol) -> Vec { - protocol.queries.iter().fold(Vec::new(), |mut sets, query| { - if let Some(pos) = sets.iter().position(|set| set.rotation == query.rotation) { - sets[pos].polys.push(query.poly) - } else { - sets.push(RotationSet { - rotation: query.rotation, - polys: vec![query.poly], - }) - } - sets - }) -} - -impl CostEstimation for PlonkAccumulationScheme { - fn estimate_cost(protocol: &Protocol) -> Cost { - let num_quotient = protocol - .relations - .iter() - .map(Expression::degree) - .max() - .unwrap() - - 1; - let num_w = rotation_sets(protocol).len(); - let num_accumulator = protocol - .accumulator_indices - .as_ref() - .map(|accumulator_indices| accumulator_indices.len()) - .unwrap_or_default(); - - let num_statement = protocol.num_statement.iter().sum(); - let num_commitment = protocol.num_auxiliary.iter().sum::() + num_quotient + num_w; - let num_evaluation = protocol.evaluations.len(); - let num_msm = - protocol.preprocessed.len() + num_commitment + 1 + num_w + 2 * num_accumulator; - - Cost::new(num_statement, num_commitment, num_evaluation, num_msm) - } -} diff --git a/src/scheme/kzg/accumulation/shplonk.rs b/src/scheme/kzg/accumulation/shplonk.rs deleted file mode 100644 index e9c31aa8..00000000 --- a/src/scheme/kzg/accumulation/shplonk.rs +++ /dev/null @@ -1,593 +0,0 @@ -use crate::{ - loader::{LoadedScalar, Loader}, - protocol::Protocol, - scheme::kzg::{ - accumulation::{AccumulationScheme, AccumulationStrategy, Accumulator}, - cost::{Cost, CostEstimation}, - langranges, - msm::MSM, - }, - util::{ - CommonPolynomial, CommonPolynomialEvaluation, Curve, Domain, Expression, Field, Fraction, - Itertools, Query, Rotation, TranscriptRead, - }, - Error, -}; -use std::{ - collections::{BTreeSet, HashMap}, - iter, -}; - -#[derive(Default)] -pub struct ShplonkAccumulationScheme; - -impl AccumulationScheme for ShplonkAccumulationScheme -where - C: Curve, - L: Loader, - T: TranscriptRead, - S: AccumulationStrategy>, -{ - type Proof = ShplonkProof; - - fn accumulate( - protocol: &Protocol, - loader: &L, - statements: Vec>, - transcript: &mut T, - strategy: &mut S, - ) -> Result { - transcript.common_scalar(&loader.load_const(&protocol.transcript_initial_state))?; - - let proof = ShplonkProof::read(protocol, statements, transcript)?; - let old_accumulator = - strategy.extract_accumulator(protocol, loader, transcript, &proof.statements); - - let (common_poly_eval, sets) = { - let mut common_poly_eval = CommonPolynomialEvaluation::new( - &protocol.domain, - loader, - langranges(protocol, &proof.statements), - &proof.z, - ); - let mut sets = intermediate_sets(protocol, loader, &proof.z, &proof.z_prime); - - L::LoadedScalar::batch_invert( - iter::empty() - .chain(common_poly_eval.denoms()) - .chain(sets.iter_mut().flat_map(IntermediateSet::denoms)), - ); - L::LoadedScalar::batch_invert(sets.iter_mut().flat_map(IntermediateSet::denoms)); - - (common_poly_eval, sets) - }; - - let commitments = proof.commitments(protocol, loader, &common_poly_eval); - let evaluations = proof.evaluations(protocol, loader, &common_poly_eval)?; - - let f = { - let powers_of_mu = proof - .mu - .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); - let msms = sets - .iter() - .map(|set| set.msm(&commitments, &evaluations, &powers_of_mu)); - - msms.zip(proof.gamma.powers(sets.len()).into_iter()) - .map(|(msm, power_of_gamma)| msm * &power_of_gamma) - .sum::>() - - MSM::base(proof.w.clone()) * &sets[0].z_s - }; - - let rhs = MSM::base(proof.w_prime.clone()); - let lhs = f + rhs.clone() * &proof.z_prime; - - let mut accumulator = Accumulator::new(lhs, rhs); - if let Some(old_accumulator) = old_accumulator { - accumulator += old_accumulator; - } - strategy.process(loader, transcript, proof, accumulator) - } -} - -pub struct ShplonkProof> { - statements: Vec>, - auxiliaries: Vec, - challenges: Vec, - alpha: L::LoadedScalar, - quotients: Vec, - z: L::LoadedScalar, - evaluations: Vec, - mu: L::LoadedScalar, - gamma: L::LoadedScalar, - w: L::LoadedEcPoint, - z_prime: L::LoadedScalar, - w_prime: L::LoadedEcPoint, -} - -impl> ShplonkProof { - fn read>( - protocol: &Protocol, - statements: Vec>, - transcript: &mut T, - ) -> Result { - if protocol.num_statement - != statements - .iter() - .map(|statements| statements.len()) - .collect_vec() - { - return Err(Error::InvalidInstances); - } - for statements in statements.iter() { - for statement in statements.iter() { - transcript.common_scalar(statement)?; - } - } - - let (auxiliaries, challenges) = { - let (auxiliaries, challenges) = protocol - .num_auxiliary - .iter() - .zip(protocol.num_challenge.iter()) - .map(|(&n, &m)| { - Ok(( - transcript.read_n_ec_points(n)?, - transcript.squeeze_n_challenges(m), - )) - }) - .collect::, Error>>()? - .into_iter() - .unzip::<_, _, Vec<_>, Vec<_>>(); - - ( - auxiliaries.into_iter().flatten().collect_vec(), - challenges.into_iter().flatten().collect_vec(), - ) - }; - - let alpha = transcript.squeeze_challenge(); - let quotients = { - let max_degree = protocol - .relations - .iter() - .map(Expression::degree) - .max() - .unwrap(); - transcript.read_n_ec_points(max_degree - 1)? - }; - - let z = transcript.squeeze_challenge(); - let evaluations = transcript.read_n_scalars(protocol.evaluations.len())?; - - let mu = transcript.squeeze_challenge(); - let gamma = transcript.squeeze_challenge(); - let w = transcript.read_ec_point()?; - let z_prime = transcript.squeeze_challenge(); - let w_prime = transcript.read_ec_point()?; - - Ok(Self { - statements, - auxiliaries, - challenges, - alpha, - quotients, - z, - evaluations, - mu, - gamma, - w, - z_prime, - w_prime, - }) - } - - fn commitments( - &self, - protocol: &Protocol, - loader: &L, - common_poly_eval: &CommonPolynomialEvaluation, - ) -> HashMap> { - iter::empty() - .chain( - protocol - .preprocessed - .iter() - .map(|value| MSM::base(loader.ec_point_load_const(value))) - .enumerate(), - ) - .chain({ - let auxiliary_offset = protocol.preprocessed.len() + protocol.num_statement.len(); - self.auxiliaries - .iter() - .cloned() - .enumerate() - .map(move |(i, auxiliary)| (auxiliary_offset + i, MSM::base(auxiliary))) - }) - .chain(iter::once(( - protocol.vanishing_poly(), - common_poly_eval - .zn() - .powers(self.quotients.len()) - .into_iter() - .zip(self.quotients.iter().cloned().map(MSM::base)) - .map(|(coeff, piece)| piece * &coeff) - .sum(), - ))) - .collect() - } - - fn evaluations( - &self, - protocol: &Protocol, - loader: &L, - common_poly_eval: &CommonPolynomialEvaluation, - ) -> Result, Error> { - let statement_evaluations = self.statements.iter().map(|statements| { - L::LoadedScalar::sum( - &statements - .iter() - .enumerate() - .map(|(i, statement)| { - statement.clone() - * common_poly_eval.get(CommonPolynomial::Lagrange(i as i32)) - }) - .collect_vec(), - ) - }); - let mut evaluations = HashMap::::from_iter( - iter::empty() - .chain( - statement_evaluations - .into_iter() - .enumerate() - .map(|(i, evaluation)| { - ( - Query { - poly: protocol.preprocessed.len() + i, - rotation: Rotation::cur(), - }, - evaluation, - ) - }), - ) - .chain( - protocol - .evaluations - .iter() - .cloned() - .zip(self.evaluations.iter().cloned()), - ), - ); - - let powers_of_alpha = self.alpha.powers(protocol.relations.len()); - let quotient_evaluation = L::LoadedScalar::sum( - &powers_of_alpha - .into_iter() - .rev() - .zip(protocol.relations.iter()) - .map(|(power_of_alpha, relation)| { - relation - .evaluate( - &|scalar| Ok(loader.load_const(&scalar)), - &|poly| Ok(common_poly_eval.get(poly)), - &|index| { - evaluations - .get(&index) - .cloned() - .ok_or(Error::MissingQuery(index)) - }, - &|index| { - self.challenges - .get(index) - .cloned() - .ok_or(Error::MissingChallenge(index)) - }, - &|a| a.map(|a| -a), - &|a, b| a.and_then(|a| Ok(a + b?)), - &|a, b| a.and_then(|a| Ok(a * b?)), - &|a, scalar| a.map(|a| a * loader.load_const(&scalar)), - ) - .map(|evaluation| power_of_alpha * evaluation) - }) - .collect::, Error>>()?, - ) * &common_poly_eval.zn_minus_one_inv(); - - evaluations.insert( - Query { - poly: protocol.vanishing_poly(), - rotation: Rotation::cur(), - }, - quotient_evaluation, - ); - - Ok(evaluations) - } -} - -struct IntermediateSet> { - rotations: Vec, - polys: Vec, - z_s: L::LoadedScalar, - evaluation_coeffs: Vec>, - commitment_coeff: Option>, - remainder_coeff: Option>, -} - -impl> IntermediateSet { - fn new( - domain: &Domain, - loader: &L, - rotations: Vec, - powers_of_z: &[L::LoadedScalar], - z_prime: &L::LoadedScalar, - z_prime_minus_z_omega_i: &HashMap, - z_s_1: &Option, - ) -> Self { - let omegas = rotations - .iter() - .map(|rotation| domain.rotate_scalar(C::Scalar::one(), *rotation)) - .collect_vec(); - - let normalized_ell_primes = omegas - .iter() - .enumerate() - .map(|(j, omega_j)| { - omegas - .iter() - .enumerate() - .filter(|&(i, _)| i != j) - .fold(C::Scalar::one(), |acc, (_, omega_i)| { - acc * (*omega_j - omega_i) - }) - }) - .collect_vec(); - - let z = &powers_of_z[1].clone(); - let z_pow_k_minus_one = { - let k_minus_one = rotations.len() - 1; - powers_of_z.iter().enumerate().skip(1).fold( - loader.load_one(), - |acc, (i, power_of_z)| { - if k_minus_one & (1 << i) == 1 { - acc * power_of_z - } else { - acc - } - }, - ) - }; - - let barycentric_weights = omegas - .iter() - .zip(normalized_ell_primes.iter()) - .map(|(omega, normalized_ell_prime)| { - L::LoadedScalar::sum_products_with_coeff_and_constant( - &[ - ( - *normalized_ell_prime, - z_pow_k_minus_one.clone(), - z_prime.clone(), - ), - ( - -(*normalized_ell_prime * omega), - z_pow_k_minus_one.clone(), - z.clone(), - ), - ], - &C::Scalar::zero(), - ) - }) - .map(Fraction::one_over) - .collect_vec(); - - let z_s = rotations - .iter() - .map(|rotation| z_prime_minus_z_omega_i.get(rotation).unwrap().clone()) - .reduce(|acc, z_prime_minus_z_omega_i| acc * z_prime_minus_z_omega_i) - .unwrap(); - let z_s_1_over_z_s = z_s_1.clone().map(|z_s_1| Fraction::new(z_s_1, z_s.clone())); - - Self { - rotations, - polys: Vec::new(), - z_s, - evaluation_coeffs: barycentric_weights, - commitment_coeff: z_s_1_over_z_s, - remainder_coeff: None, - } - } - - fn denoms(&mut self) -> impl IntoIterator { - if self.evaluation_coeffs.first().unwrap().denom().is_some() { - self.evaluation_coeffs - .iter_mut() - .chain(self.commitment_coeff.as_mut()) - .filter_map(Fraction::denom_mut) - .collect_vec() - } else if self.remainder_coeff.is_none() { - let barycentric_weights_sum = L::LoadedScalar::sum( - &self - .evaluation_coeffs - .iter() - .map(Fraction::evaluate) - .collect_vec(), - ); - self.remainder_coeff = Some(match self.commitment_coeff.clone() { - Some(coeff) => Fraction::new(coeff.evaluate(), barycentric_weights_sum), - None => Fraction::one_over(barycentric_weights_sum), - }); - vec![self.remainder_coeff.as_mut().unwrap().denom_mut().unwrap()] - } else { - unreachable!() - } - } - - fn msm( - &self, - commitments: &HashMap>, - evaluations: &HashMap, - powers_of_mu: &[L::LoadedScalar], - ) -> MSM { - self.polys - .iter() - .zip(powers_of_mu.iter()) - .map(|(poly, power_of_mu)| { - let commitment = self - .commitment_coeff - .as_ref() - .map(|commitment_coeff| { - commitments.get(poly).unwrap().clone() * &commitment_coeff.evaluate() - }) - .unwrap_or_else(|| commitments.get(poly).unwrap().clone()); - let remainder = self.remainder_coeff.as_ref().unwrap().evaluate() - * L::LoadedScalar::sum( - &self - .rotations - .iter() - .zip(self.evaluation_coeffs.iter()) - .map(|(rotation, coeff)| { - coeff.evaluate() - * evaluations - .get(&Query { - poly: *poly, - rotation: *rotation, - }) - .unwrap() - }) - .collect_vec(), - ); - (commitment - MSM::scalar(remainder)) * power_of_mu - }) - .sum() - } -} - -fn intermediate_sets>( - protocol: &Protocol, - loader: &L, - z: &L::LoadedScalar, - z_prime: &L::LoadedScalar, -) -> Vec> { - let rotations_sets = rotations_sets(protocol); - let superset = rotations_sets - .iter() - .flat_map(|set| set.rotations.clone()) - .sorted() - .dedup(); - - let size = 2.max( - (rotations_sets - .iter() - .map(|set| set.rotations.len()) - .max() - .unwrap() - - 1) - .next_power_of_two() - .log2() as usize - + 1, - ); - let powers_of_z = z.powers(size); - let z_prime_minus_z_omega_i = HashMap::from_iter( - superset - .map(|rotation| { - ( - rotation, - loader.load_const(&protocol.domain.rotate_scalar(C::Scalar::one(), rotation)), - ) - }) - .map(|(rotation, omega)| (rotation, z_prime.clone() - z.clone() * omega)), - ); - - let mut z_s_1 = None; - rotations_sets - .into_iter() - .map(|set| { - let intermetidate_set = IntermediateSet { - polys: set.polys, - ..IntermediateSet::new( - &protocol.domain, - loader, - set.rotations, - &powers_of_z, - z_prime, - &z_prime_minus_z_omega_i, - &z_s_1, - ) - }; - if z_s_1.is_none() { - z_s_1 = Some(intermetidate_set.z_s.clone()); - }; - intermetidate_set - }) - .collect() -} - -struct RotationsSet { - rotations: Vec, - polys: Vec, -} - -fn rotations_sets(protocol: &Protocol) -> Vec { - let poly_rotations = protocol.queries.iter().fold( - Vec::<(usize, Vec)>::new(), - |mut poly_rotations, query| { - if let Some(pos) = poly_rotations - .iter() - .position(|(poly, _)| *poly == query.poly) - { - let (_, rotations) = &mut poly_rotations[pos]; - if !rotations.contains(&query.rotation) { - rotations.push(query.rotation); - } - } else { - poly_rotations.push((query.poly, vec![query.rotation])); - } - poly_rotations - }, - ); - - poly_rotations - .into_iter() - .fold(Vec::::new(), |mut sets, (poly, rotations)| { - if let Some(pos) = sets.iter().position(|set| { - BTreeSet::from_iter(set.rotations.iter()) == BTreeSet::from_iter(rotations.iter()) - }) { - let set = &mut sets[pos]; - if !set.polys.contains(&poly) { - set.polys.push(poly); - } - } else { - let set = RotationsSet { - rotations, - polys: vec![poly], - }; - sets.push(set); - } - sets - }) -} - -impl CostEstimation for ShplonkAccumulationScheme { - fn estimate_cost(protocol: &Protocol) -> Cost { - let num_quotient = protocol - .relations - .iter() - .map(Expression::degree) - .max() - .unwrap() - - 1; - let num_accumulator = protocol - .accumulator_indices - .as_ref() - .map(|accumulator_indices| accumulator_indices.len()) - .unwrap_or_default(); - - let num_statement = protocol.num_statement.iter().sum(); - let num_commitment = protocol.num_auxiliary.iter().sum::() + num_quotient + 2; - let num_evaluation = protocol.evaluations.len(); - let num_msm = protocol.preprocessed.len() + num_commitment + 3 + 2 * num_accumulator; - - Cost::new(num_statement, num_commitment, num_evaluation, num_msm) - } -} diff --git a/src/scheme/kzg/cost.rs b/src/scheme/kzg/cost.rs deleted file mode 100644 index f83f7dd1..00000000 --- a/src/scheme/kzg/cost.rs +++ /dev/null @@ -1,29 +0,0 @@ -use crate::{protocol::Protocol, util::Curve}; - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct Cost { - pub num_statement: usize, - pub num_commitment: usize, - pub num_evaluation: usize, - pub num_msm: usize, -} - -impl Cost { - pub fn new( - num_statement: usize, - num_commitment: usize, - num_evaluation: usize, - num_msm: usize, - ) -> Self { - Self { - num_statement, - num_commitment, - num_evaluation, - num_msm, - } - } -} - -pub trait CostEstimation { - fn estimate_cost(protocol: &Protocol) -> Cost; -} diff --git a/src/scheme/kzg/msm.rs b/src/scheme/kzg/msm.rs deleted file mode 100644 index db5e4ce2..00000000 --- a/src/scheme/kzg/msm.rs +++ /dev/null @@ -1,149 +0,0 @@ -use crate::{ - loader::{LoadedEcPoint, Loader}, - util::Curve, -}; -use std::{ - default::Default, - iter::{self, Sum}, - ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, -}; - -#[derive(Clone, Debug)] -pub struct MSM> { - pub scalar: Option, - bases: Vec, - scalars: Vec, -} - -impl> Default for MSM { - fn default() -> Self { - Self { - scalar: None, - scalars: Vec::new(), - bases: Vec::new(), - } - } -} - -impl> MSM { - pub fn scalar(scalar: L::LoadedScalar) -> Self { - MSM { - scalar: Some(scalar), - ..Default::default() - } - } - - pub fn base(base: L::LoadedEcPoint) -> Self { - let one = base.loader().load_one(); - MSM { - scalars: vec![one], - bases: vec![base], - ..Default::default() - } - } - - pub fn evaluate(self, gen: C) -> L::LoadedEcPoint { - let gen = self - .bases - .first() - .unwrap() - .loader() - .ec_point_load_const(&gen); - L::LoadedEcPoint::multi_scalar_multiplication( - iter::empty() - .chain(self.scalar.map(|scalar| (scalar, gen))) - .chain(self.scalars.into_iter().zip(self.bases.into_iter())), - ) - } - - pub fn scale(&mut self, factor: &L::LoadedScalar) { - if let Some(scalar) = self.scalar.as_mut() { - *scalar *= factor; - } - for scalar in self.scalars.iter_mut() { - *scalar *= factor - } - } - - pub fn push(&mut self, scalar: L::LoadedScalar, base: L::LoadedEcPoint) { - if let Some(pos) = self.bases.iter().position(|exist| exist.eq(&base)) { - self.scalars[pos] += scalar; - } else { - self.scalars.push(scalar); - self.bases.push(base); - } - } - - pub fn extend(&mut self, mut other: Self) { - match (self.scalar.as_mut(), other.scalar.as_ref()) { - (Some(lhs), Some(rhs)) => *lhs += rhs, - (None, Some(_)) => self.scalar = other.scalar.take(), - _ => {} - }; - for (scalar, base) in other.scalars.into_iter().zip(other.bases) { - self.push(scalar, base); - } - } -} - -impl> Add> for MSM { - type Output = MSM; - - fn add(mut self, rhs: MSM) -> Self::Output { - self.extend(rhs); - self - } -} - -impl> AddAssign> for MSM { - fn add_assign(&mut self, rhs: MSM) { - self.extend(rhs); - } -} - -impl> Sub> for MSM { - type Output = MSM; - - fn sub(mut self, rhs: MSM) -> Self::Output { - self.extend(-rhs); - self - } -} - -impl> SubAssign> for MSM { - fn sub_assign(&mut self, rhs: MSM) { - self.extend(-rhs); - } -} - -impl> Mul<&L::LoadedScalar> for MSM { - type Output = MSM; - - fn mul(mut self, rhs: &L::LoadedScalar) -> Self::Output { - self.scale(rhs); - self - } -} - -impl> MulAssign<&L::LoadedScalar> for MSM { - fn mul_assign(&mut self, rhs: &L::LoadedScalar) { - self.scale(rhs); - } -} - -impl> Neg for MSM { - type Output = MSM; - fn neg(mut self) -> MSM { - self.scalar = self.scalar.map(|scalar| -scalar); - for scalar in self.scalars.iter_mut() { - *scalar = -scalar.clone(); - } - self - } -} - -impl> Sum for MSM { - fn sum>(iter: I) -> Self { - iter.reduce(|acc, item| acc + item).unwrap_or_default() - } -} diff --git a/src/system.rs b/src/system.rs new file mode 100644 index 00000000..5d5aa99c --- /dev/null +++ b/src/system.rs @@ -0,0 +1,2 @@ +#[cfg(feature = "system_halo2")] +pub mod halo2; diff --git a/src/protocol/halo2.rs b/src/system/halo2.rs similarity index 76% rename from src/protocol/halo2.rs rename to src/system/halo2.rs index 6c30bf5d..bf3e4091 100644 --- a/src/protocol/halo2.rs +++ b/src/system/halo2.rs @@ -1,40 +1,99 @@ use crate::{ - protocol::Protocol, - util::{CommonPolynomial, Domain, Expression, Itertools, Query, Rotation}, + util::{ + arithmetic::{root_of_unity, CurveAffine, Domain, FieldExt, Rotation}, + protocol::{ + CommonPolynomial, Expression, InstanceCommittingKey, Query, QuotientPolynomial, + }, + Itertools, + }, + Protocol, }; use halo2_proofs::{ - arithmetic::{CurveAffine, CurveExt, FieldExt}, plonk::{self, Any, ConstraintSystem, FirstPhase, SecondPhase, ThirdPhase, VerifyingKey}, - poly, + poly::{self, commitment::Params}, transcript::{EncodedChallenge, Transcript}, }; -use std::{io, iter}; +use num_integer::Integer; +use std::{io, iter, mem::size_of}; -mod util; +pub mod transcript; #[cfg(test)] -mod test; +pub(crate) mod test; +#[derive(Clone, Debug, Default)] pub struct Config { - zk: bool, - query_instance: bool, - num_instance: Vec, - num_proof: usize, - accumulator_indices: Option>, + pub zk: bool, + pub query_instance: bool, + pub num_proof: usize, + pub num_instance: Vec, + pub accumulator_indices: Option>, } -pub fn compile(vk: &VerifyingKey, config: Config) -> Protocol { +impl Config { + pub fn kzg() -> Self { + Self { + zk: true, + query_instance: false, + num_proof: 1, + ..Default::default() + } + } + + pub fn ipa() -> Self { + Self { + zk: true, + query_instance: true, + num_proof: 1, + ..Default::default() + } + } + + pub fn set_zk(mut self, zk: bool) -> Self { + self.zk = zk; + self + } + + pub fn set_query_instance(mut self, query_instance: bool) -> Self { + self.query_instance = query_instance; + self + } + + pub fn with_num_proof(mut self, num_proof: usize) -> Self { + assert!(num_proof > 0); + self.num_proof = num_proof; + self + } + + pub fn with_num_instance(mut self, num_instance: Vec) -> Self { + self.num_instance = num_instance; + self + } + + pub fn with_accumulator_indices(mut self, accumulator_indices: Vec<(usize, usize)>) -> Self { + self.accumulator_indices = Some(accumulator_indices); + self + } +} + +pub fn compile<'a, C: CurveAffine, P: Params<'a, C>>( + params: &P, + vk: &VerifyingKey, + config: Config, +) -> Protocol { + assert_eq!(vk.get_domain().k(), params.k()); + let cs = vk.cs(); let Config { zk, - num_instance, query_instance, num_proof, + num_instance, accumulator_indices, } = config; - let k = vk.get_domain().empty_lagrange().len().log2(); - let domain = Domain::new(k as usize); + let k = params.k() as usize; + let domain = Domain::new(k, root_of_unity(k)); let preprocessed = vk .fixed_commitments() @@ -66,35 +125,39 @@ pub fn compile(vk: &VerifyingKey, config: Config) -> }) .chain(polynomials.fixed_queries()) .chain(polynomials.permutation_fixed_queries()) - .chain(iter::once(polynomials.vanishing_query())) + .chain(iter::once(polynomials.quotient_query())) .chain(polynomials.random_query()) .collect(); - let relations = (0..num_proof) - .flat_map(|t| { - iter::empty() - .chain(polynomials.gate_relations(t)) - .chain(polynomials.permutation_relations(t)) - .chain(polynomials.lookup_relations(t)) - }) - .collect(); - let transcript_initial_state = transcript_initial_state::(vk); + let instance_committing_key = query_instance.then(|| { + instance_committing_key( + params, + polynomials + .num_instance() + .into_iter() + .max() + .unwrap_or_default(), + ) + }); + let accumulator_indices = accumulator_indices - .map(|accumulator_indices| polynomials.accumulator_indices(accumulator_indices)); + .map(|accumulator_indices| polynomials.accumulator_indices(accumulator_indices)) + .unwrap_or_default(); Protocol { - zk: config.zk, domain, preprocessed, - num_statement: polynomials.num_statement(), - num_auxiliary: polynomials.num_auxiliary(), + num_instance: polynomials.num_instance(), + num_witness: polynomials.num_witness(), num_challenge: polynomials.num_challenge(), evaluations, queries, - relations, - transcript_initial_state, + quotient: polynomials.quotient(), + transcript_initial_state: Some(transcript_initial_state), + instance_committing_key, + linearization: None, accumulator_indices, } } @@ -131,11 +194,8 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { num_instance: Vec, num_proof: usize, ) -> Self { - let degree = if zk { - cs.degree::() - } else { - cs.degree::() - }; + // TODO: Re-enable optional-zk when it's merged in pse/halo2. + let degree = if zk { cs.degree() } else { unimplemented!() }; let permutation_chunk_size = if zk || cs.permutation().get_columns().len() >= degree { degree - 2 } else { @@ -155,7 +215,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { state[*phase as usize] += 1; Some(index) }) - .collect_vec(); + .collect::>(); (num, index) }; @@ -164,8 +224,6 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { assert_eq!(num_advice.iter().sum::(), cs.num_advice_columns()); assert_eq!(num_challenge.iter().sum::(), cs.num_challenges()); - assert_eq!(cs.num_instance_columns(), num_instance.len()); - Self { cs, zk, @@ -180,11 +238,10 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { challenge_index, num_lookup_permuted: 2 * cs.lookups().len(), permutation_chunk_size, - num_permutation_z: cs - .permutation() - .get_columns() - .len() - .div_ceil(permutation_chunk_size), + num_permutation_z: Integer::div_ceil( + &cs.permutation().get_columns().len(), + &permutation_chunk_size, + ), num_lookup_z: cs.lookups().len(), } } @@ -193,14 +250,14 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { self.num_fixed + self.num_permutation_fixed } - fn num_statement(&self) -> Vec { + fn num_instance(&self) -> Vec { iter::repeat(self.num_instance.clone()) .take(self.num_proof) .flatten() .collect() } - fn num_auxiliary(&self) -> Vec { + fn num_witness(&self) -> Vec { iter::empty() .chain( self.num_advice @@ -222,7 +279,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .chain(num_challenge) .chain([ 2, // beta, gamma - 0, + 1, // alpha ]) .collect() } @@ -231,14 +288,14 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { self.num_preprocessed() } - fn auxiliary_offset(&self) -> usize { - self.instance_offset() + self.num_statement().len() + fn witness_offset(&self) -> usize { + self.instance_offset() + self.num_instance().len() } - fn cs_auxiliary_offset(&self) -> usize { - self.auxiliary_offset() + fn cs_witness_offset(&self) -> usize { + self.witness_offset() + self - .num_auxiliary() + .num_witness() .iter() .take(self.num_advice.len()) .sum::() @@ -260,9 +317,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { * self.num_advice[..advice.phase() as usize] .iter() .sum::(); - self.auxiliary_offset() - + phase_offset - + t * self.num_advice[advice.phase() as usize] + self.witness_offset() + phase_offset + t * self.num_advice[advice.phase() as usize] } }; Query::new(offset + column_index, rotation.into()) @@ -270,14 +325,14 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { fn instance_queries(&'a self, t: usize) -> impl IntoIterator + 'a { self.query_instance - .then_some( + .then(|| { self.cs .instance_queries() .iter() .map(move |(column, rotation)| { self.query(*column.column_type(), column.index(), *rotation, t) - }), - ) + }) + }) .into_iter() .flatten() } @@ -305,7 +360,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { } fn permutation_poly(&'a self, t: usize, i: usize) -> usize { - let z_offset = self.cs_auxiliary_offset() + self.num_auxiliary()[self.num_advice.len()]; + let z_offset = self.cs_witness_offset() + self.num_witness()[self.num_advice.len()]; z_offset + t * self.num_permutation_z + i } @@ -346,9 +401,9 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { } fn lookup_poly(&'a self, t: usize, i: usize) -> (usize, usize, usize) { - let permuted_offset = self.cs_auxiliary_offset(); + let permuted_offset = self.cs_witness_offset(); let z_offset = permuted_offset - + self.num_auxiliary()[self.num_advice.len()] + + self.num_witness()[self.num_advice.len()] + self.num_proof * self.num_permutation_z; let z = z_offset + t * self.num_lookup_z + i; let permuted_input = permuted_offset + 2 * (t * self.num_lookup_z + i); @@ -382,18 +437,20 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { }) } - fn vanishing_query(&self) -> Query { + fn quotient_query(&self) -> Query { Query::new( - self.auxiliary_offset() + self.num_auxiliary().iter().sum::(), + self.witness_offset() + self.num_witness().iter().sum::(), 0, ) } fn random_query(&self) -> Option { - self.zk.then_some(Query::new( - self.auxiliary_offset() + self.num_auxiliary().iter().sum::() - 1, - 0, - )) + self.zk.then(|| { + Query::new( + self.witness_offset() + self.num_witness().iter().sum::() - 1, + 0, + ) + }) } fn convert(&self, expression: &plonk::Expression, t: usize) -> Expression { @@ -435,7 +492,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { ) } - fn gate_relations(&'a self, t: usize) -> impl IntoIterator> + 'a { + fn gate_constraints(&'a self, t: usize) -> impl IntoIterator> + 'a { self.cs.gates().iter().flat_map(move |gate| { gate.polynomials() .iter() @@ -444,7 +501,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { } fn rotation_last(&self) -> Rotation { - Rotation(-((self.cs.blinding_factors::() + 1) as i32)) + Rotation(-((self.cs.blinding_factors() + 1) as i32)) } fn l_last(&self) -> Expression { @@ -483,7 +540,11 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { Expression::Challenge(self.system_challenge_offset() + 2) } - fn permutation_relations(&'a self, t: usize) -> impl IntoIterator> + 'a { + fn alpha(&self) -> Expression { + Expression::Challenge(self.system_challenge_offset() + 3) + } + + fn permutation_constraints(&'a self, t: usize) -> impl IntoIterator> + 'a { let one = &Expression::Constant(F::one()); let l_0 = &Expression::::CommonPolynomial(CommonPolynomial::Lagrange(0)); let l_last = &self.l_last(); @@ -519,7 +580,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .chain(zs.first().map(|(z_0, _, _)| l_0 * (one - z_0))) .chain( zs.last() - .and_then(|(z_l, _, _)| self.zk.then_some(l_last * (z_l * z_l - z_l))), + .and_then(|(z_l, _, _)| self.zk.then(|| l_last * (z_l * z_l - z_l))), ) .chain(if self.zk { zs.iter() @@ -575,12 +636,11 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .collect_vec() } - fn lookup_relations(&'a self, t: usize) -> impl IntoIterator> + 'a { + fn lookup_constraints(&'a self, t: usize) -> impl IntoIterator> + 'a { let one = &Expression::Constant(F::one()); let l_0 = &Expression::::CommonPolynomial(CommonPolynomial::Lagrange(0)); let l_last = &self.l_last(); let l_active = &self.l_active(); - let theta = &self.theta(); let beta = &self.beta(); let gamma = &self.gamma(); @@ -598,15 +658,13 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .collect_vec(); let compress = |expressions: &'a [plonk::Expression]| { - expressions - .iter() - .rev() - .zip(iter::successors(Some(one.clone()), |power_of_theta| { - Some(power_of_theta * theta) - })) - .map(|(expression, power_of_theta)| power_of_theta * self.convert(expression, t)) - .reduce(|acc, expr| acc + expr) - .unwrap() + Expression::DistributePowers( + expressions + .iter() + .map(|expression| self.convert(expression, t)) + .collect(), + self.theta().into(), + ) }; self.cs @@ -619,7 +677,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { let table = compress(lookup.table_expressions()); iter::empty() .chain(Some(l_0 * (one - z))) - .chain(self.zk.then_some(l_last * (z * z - z))) + .chain(self.zk.then(|| l_last * (z * z - z))) .chain(Some(if self.zk { l_active * (z_w * (permuted_input + beta) * (permuted_table + gamma) @@ -628,7 +686,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { z_w * (permuted_input + beta) * (permuted_table + gamma) - z * (input + beta) * (table + gamma) })) - .chain(self.zk.then_some(l_0 * (permuted_input - permuted_table))) + .chain(self.zk.then(|| l_0 * (permuted_input - permuted_table))) .chain(Some(if self.zk { l_active * (permuted_input - permuted_table) @@ -642,6 +700,22 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .collect_vec() } + fn quotient(&self) -> QuotientPolynomial { + let constraints = (0..self.num_proof) + .flat_map(|t| { + iter::empty() + .chain(self.gate_constraints(t)) + .chain(self.permutation_constraints(t)) + .chain(self.lookup_constraints(t)) + }) + .collect_vec(); + let numerator = Expression::DistributePowers(constraints, self.alpha().into()); + QuotientPolynomial { + chunk_degree: 1, + numerator, + } + } + fn accumulator_indices( &self, accumulator_indices: Vec<(usize, usize)>, @@ -690,8 +764,47 @@ impl Transcript for MockTranscript } } -fn transcript_initial_state(vk: &VerifyingKey) -> C::ScalarExt { +fn transcript_initial_state(vk: &VerifyingKey) -> C::Scalar { let mut transcript = MockTranscript::default(); vk.hash_into(&mut transcript).unwrap(); transcript.0 } + +fn instance_committing_key<'a, C: CurveAffine, P: Params<'a, C>>( + params: &P, + len: usize, +) -> InstanceCommittingKey { + let buf = { + let mut buf = Vec::new(); + params.write(&mut buf).unwrap(); + buf + }; + + let repr = C::Repr::default(); + let repr_len = repr.as_ref().len(); + let offset = size_of::() + (1 << params.k()) * repr_len; + + let bases = (offset..) + .step_by(repr_len) + .map(|offset| { + let mut repr = C::Repr::default(); + repr.as_mut() + .copy_from_slice(&buf[offset..offset + repr_len]); + C::from_bytes(&repr).unwrap() + }) + .take(len) + .collect(); + + let w = { + let offset = size_of::() + (2 << params.k()) * repr_len; + let mut repr = C::Repr::default(); + repr.as_mut() + .copy_from_slice(&buf[offset..offset + repr_len]); + C::from_bytes(&repr).unwrap() + }; + + InstanceCommittingKey { + bases, + constant: Some(w), + } +} diff --git a/src/system/halo2/test.rs b/src/system/halo2/test.rs new file mode 100644 index 00000000..9cd4a2fc --- /dev/null +++ b/src/system/halo2/test.rs @@ -0,0 +1,221 @@ +use crate::util::arithmetic::CurveAffine; +use halo2_proofs::{ + dev::MockProver, + plonk::{create_proof, verify_proof, Circuit, ProvingKey}, + poly::{ + commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier}, + VerificationStrategy, + }, + transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer}, +}; +use rand_chacha::rand_core::RngCore; +use std::{fs, io::Cursor}; + +mod circuit; +mod kzg; + +pub use circuit::{ + maingate::{MainGateWithRange, MainGateWithRangeConfig}, + standard::StandardPlonk, +}; + +pub fn read_or_create_srs<'a, C: CurveAffine, P: ParamsProver<'a, C>>( + dir: &str, + k: u32, + setup: impl Fn(u32) -> P, +) -> P { + let path = format!("{}/k-{}.srs", dir, k); + match fs::File::open(path.as_str()) { + Ok(mut file) => P::read(&mut file).unwrap(), + Err(_) => { + fs::create_dir_all(dir).unwrap(); + let params = setup(k); + params.write(&mut fs::File::create(path).unwrap()).unwrap(); + params + } + } +} + +pub fn create_proof_checked<'a, S, C, P, V, VS, TW, TR, EC, R>( + params: &'a S::ParamsProver, + pk: &ProvingKey, + circuits: &[C], + instances: &[&[&[S::Scalar]]], + mut rng: R, + finalize: impl Fn(Vec, VS::Output) -> Vec, +) -> Vec +where + S: CommitmentScheme, + S::ParamsVerifier: 'a, + C: Circuit, + P: Prover<'a, S>, + V: Verifier<'a, S>, + VS: VerificationStrategy<'a, S, V>, + TW: TranscriptWriterBuffer, S::Curve, EC>, + TR: TranscriptReadBuffer>, S::Curve, EC>, + EC: EncodedChallenge, + R: RngCore, +{ + for (circuit, instances) in circuits.iter().zip(instances.iter()) { + MockProver::run( + params.k(), + circuit, + instances.iter().map(|instance| instance.to_vec()).collect(), + ) + .unwrap() + .assert_satisfied(); + } + + let proof = { + let mut transcript = TW::init(Vec::new()); + create_proof::( + params, + pk, + circuits, + instances, + &mut rng, + &mut transcript, + ) + .unwrap(); + transcript.finalize() + }; + + let output = { + let params = params.verifier_params(); + let strategy = VS::new(params); + let mut transcript = TR::init(Cursor::new(proof.clone())); + verify_proof(params, pk.get_vk(), strategy, instances, &mut transcript).unwrap() + }; + + finalize(proof, output) +} + +macro_rules! halo2_prepare { + ($dir:expr, $k:expr, $setup:expr, $config:expr, $create_circuit:expr) => {{ + use halo2_proofs::plonk::{keygen_pk, keygen_vk}; + use std::iter; + use $crate::{ + system::halo2::{compile, test::read_or_create_srs}, + util::{arithmetic::GroupEncoding, Itertools}, + }; + + let params = read_or_create_srs($dir, $k, $setup); + + let circuits = iter::repeat_with(|| $create_circuit) + .take($config.num_proof) + .collect_vec(); + + let pk = if $config.zk { + let vk = keygen_vk(¶ms, &circuits[0]).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuits[0]).unwrap(); + pk + } else { + // TODO: Re-enable optional-zk when it's merged in pse/halo2. + unimplemented!() + }; + + let num_instance = circuits[0] + .instances() + .iter() + .map(|instances| instances.len()) + .collect(); + let protocol = compile( + ¶ms, + pk.get_vk(), + $config.with_num_instance(num_instance), + ); + assert_eq!( + protocol.preprocessed.len(), + protocol + .preprocessed + .iter() + .map( + |ec_point| <[u8; 32]>::try_from(ec_point.to_bytes().as_ref().to_vec()).unwrap() + ) + .unique() + .count() + ); + + (params, pk, protocol, circuits) + }}; +} + +macro_rules! halo2_create_snark { + ( + $commitment_scheme:ty, + $prover:ty, + $verifier:ty, + $verification_strategy:ty, + $transcript_read:ty, + $transcript_write:ty, + $encoded_challenge:ty, + $finalize:expr, + $params:expr, + $pk:expr, + $protocol:expr, + $circuits:expr + ) => {{ + use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; + use $crate::{ + loader::halo2::test::Snark, system::halo2::test::create_proof_checked, util::Itertools, + }; + + let instances = $circuits + .iter() + .map(|circuit| circuit.instances()) + .collect_vec(); + let proof = { + #[allow(clippy::needless_borrow)] + let instances = instances + .iter() + .map(|instances| instances.iter().map(Vec::as_slice).collect_vec()) + .collect_vec(); + let instances = instances.iter().map(Vec::as_slice).collect_vec(); + create_proof_checked::< + $commitment_scheme, + _, + $prover, + $verifier, + $verification_strategy, + $transcript_read, + $transcript_write, + $encoded_challenge, + _, + >( + $params, + $pk, + $circuits, + &instances, + &mut ChaCha20Rng::from_seed(Default::default()), + $finalize, + ) + }; + + Snark::new( + $protocol.clone(), + instances.into_iter().flatten().collect_vec(), + proof, + ) + }}; +} + +macro_rules! halo2_native_verify { + ( + $plonk_verifier:ty, + $params:expr, + $protocol:expr, + $instances:expr, + $transcript:expr, + $svk:expr, + $dk:expr + ) => {{ + use halo2_proofs::poly::commitment::ParamsProver; + use $crate::verifier::PlonkVerifier; + + let proof = + <$plonk_verifier>::read_proof($svk, $protocol, $instances, $transcript).unwrap(); + assert!(<$plonk_verifier>::verify($svk, $dk, $protocol, $instances, &proof).unwrap()) + }}; +} + +pub(crate) use {halo2_create_snark, halo2_native_verify, halo2_prepare}; diff --git a/src/protocol/halo2/test/circuit.rs b/src/system/halo2/test/circuit.rs similarity index 67% rename from src/protocol/halo2/test/circuit.rs rename to src/system/halo2/test/circuit.rs index a87cb997..87a005fb 100644 --- a/src/protocol/halo2/test/circuit.rs +++ b/src/system/halo2/test/circuit.rs @@ -1,3 +1,2 @@ pub mod maingate; -pub mod plookup; pub mod standard; diff --git a/src/system/halo2/test/circuit/maingate.rs b/src/system/halo2/test/circuit/maingate.rs new file mode 100644 index 00000000..82d63b5e --- /dev/null +++ b/src/system/halo2/test/circuit/maingate.rs @@ -0,0 +1,111 @@ +use crate::util::arithmetic::{CurveAffine, FieldExt}; +use halo2_proofs::{ + circuit::{floor_planner::V1, Layouter, Value}, + plonk::{Circuit, ConstraintSystem, Error}, +}; +use halo2_wrong_ecc::{ + maingate::{ + MainGate, MainGateConfig, MainGateInstructions, RangeChip, RangeConfig, RangeInstructions, + RegionCtx, + }, + BaseFieldEccChip, EccConfig, +}; +use rand::RngCore; + +#[derive(Clone)] +pub struct MainGateWithRangeConfig { + main_gate_config: MainGateConfig, + range_config: RangeConfig, +} + +impl MainGateWithRangeConfig { + pub fn configure( + meta: &mut ConstraintSystem, + composition_bits: Vec, + overflow_bits: Vec, + ) -> Self { + let main_gate_config = MainGate::::configure(meta); + let range_config = + RangeChip::::configure(meta, &main_gate_config, composition_bits, overflow_bits); + MainGateWithRangeConfig { + main_gate_config, + range_config, + } + } + + pub fn main_gate(&self) -> MainGate { + MainGate::new(self.main_gate_config.clone()) + } + + pub fn range_chip(&self) -> RangeChip { + RangeChip::new(self.range_config.clone()) + } + + pub fn ecc_chip( + &self, + ) -> BaseFieldEccChip { + BaseFieldEccChip::new(EccConfig::new( + self.range_config.clone(), + self.main_gate_config.clone(), + )) + } +} + +#[derive(Clone, Default)] +pub struct MainGateWithRange(Vec); + +impl MainGateWithRange { + pub fn new(inner: Vec) -> Self { + Self(inner) + } + + pub fn rand(mut rng: R) -> Self { + Self::new(vec![F::from(rng.next_u32() as u64)]) + } + + pub fn instances(&self) -> Vec> { + vec![self.0.clone()] + } +} + +impl Circuit for MainGateWithRange { + type Config = MainGateWithRangeConfig; + type FloorPlanner = V1; + + fn without_witnesses(&self) -> Self { + Self(vec![F::zero()]) + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + MainGateWithRangeConfig::configure(meta, vec![8], vec![4, 7]) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let main_gate = config.main_gate(); + let range_chip = config.range_chip(); + range_chip.load_table(&mut layouter)?; + + let a = layouter.assign_region( + || "", + |region| { + let mut ctx = RegionCtx::new(region, 0); + range_chip.decompose(&mut ctx, Value::known(F::from(u64::MAX)), 8, 64)?; + range_chip.decompose(&mut ctx, Value::known(F::from(u32::MAX as u64)), 8, 39)?; + let a = range_chip.assign(&mut ctx, Value::known(self.0[0]), 8, 68)?; + let b = main_gate.sub_sub_with_constant(&mut ctx, &a, &a, &a, F::from(2))?; + let cond = main_gate.assign_bit(&mut ctx, Value::known(F::one()))?; + main_gate.select(&mut ctx, &a, &b, &cond)?; + + Ok(a) + }, + )?; + + main_gate.expose_public(layouter, a, 0)?; + + Ok(()) + } +} diff --git a/src/protocol/halo2/test/circuit/standard.rs b/src/system/halo2/test/circuit/standard.rs similarity index 69% rename from src/protocol/halo2/test/circuit/standard.rs rename to src/system/halo2/test/circuit/standard.rs index 0773f360..90f30f2b 100644 --- a/src/protocol/halo2/test/circuit/standard.rs +++ b/src/system/halo2/test/circuit/standard.rs @@ -1,7 +1,7 @@ +use crate::util::arithmetic::FieldExt; use halo2_proofs::{ - arithmetic::FieldExt, circuit::{floor_planner::V1, Layouter, Value}, - plonk::{Advice, Any, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, + plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, poly::Rotation, }; use rand::RngCore; @@ -22,39 +22,29 @@ pub struct StandardPlonkConfig { impl StandardPlonkConfig { pub fn configure(meta: &mut ConstraintSystem) -> Self { - let a = meta.advice_column(); - let b = meta.advice_column(); - let c = meta.advice_column(); - - let q_a = meta.fixed_column(); - let q_b = meta.fixed_column(); - let q_c = meta.fixed_column(); - - let q_ab = meta.fixed_column(); - - let constant = meta.fixed_column(); + let [a, b, c] = [(); 3].map(|_| meta.advice_column()); + let [q_a, q_b, q_c, q_ab, constant] = [(); 5].map(|_| meta.fixed_column()); let instance = meta.instance_column(); - meta.enable_equality(a); - meta.enable_equality(b); - meta.enable_equality(c); - - meta.create_gate("", |meta| { - let [a, b, c, q_a, q_b, q_c, q_ab, constant, instance] = [ - a.into(), - b.into(), - c.into(), - q_a.into(), - q_b.into(), - q_c.into(), - q_ab.into(), - constant.into(), - instance.into(), - ] - .map(|column: Column| meta.query_any(column, Rotation::cur())); - - vec![q_a * a.clone() + q_b * b.clone() + q_c * c + q_ab * a * b + constant + instance] - }); + [a, b, c].map(|column| meta.enable_equality(column)); + + meta.create_gate( + "q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0", + |meta| { + let [a, b, c] = [a, b, c].map(|column| meta.query_advice(column, Rotation::cur())); + let [q_a, q_b, q_c, q_ab, constant] = [q_a, q_b, q_c, q_ab, constant] + .map(|column| meta.query_fixed(column, Rotation::cur())); + let instance = meta.query_instance(instance, Rotation::cur()); + Some( + q_a * a.clone() + + q_b * b.clone() + + q_c * c + + q_ab * a * b + + constant + + instance, + ) + }, + ); StandardPlonkConfig { a, diff --git a/src/system/halo2/test/kzg.rs b/src/system/halo2/test/kzg.rs new file mode 100644 index 00000000..0b071175 --- /dev/null +++ b/src/system/halo2/test/kzg.rs @@ -0,0 +1,120 @@ +use crate::{ + system::halo2::test::{read_or_create_srs, MainGateWithRange}, + util::arithmetic::{fe_to_limbs, CurveAffine, MultiMillerLoop}, +}; +use halo2_proofs::poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}; +use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; + +mod native; + +#[cfg(feature = "loader_evm")] +mod evm; + +#[cfg(feature = "loader_halo2")] +mod halo2; + +pub const TESTDATA_DIR: &str = "./src/system/halo2/test/kzg/testdata"; + +pub const LIMBS: usize = 4; +pub const BITS: usize = 68; + +pub fn setup(k: u32) -> ParamsKZG { + ParamsKZG::::setup(k, ChaCha20Rng::from_seed(Default::default())) +} + +pub fn main_gate_with_range_with_mock_kzg_accumulator( +) -> MainGateWithRange { + let srs = read_or_create_srs(TESTDATA_DIR, 1, setup::); + let [g1, s_g1] = [srs.get_g()[0], srs.get_g()[1]].map(|point| point.coordinates().unwrap()); + MainGateWithRange::new( + [*s_g1.x(), *s_g1.y(), *g1.x(), *g1.y()] + .iter() + .cloned() + .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) + .collect(), + ) +} + +macro_rules! halo2_kzg_config { + ($zk:expr, $num_proof:expr) => { + $crate::system::halo2::Config::kzg() + .set_zk($zk) + .with_num_proof($num_proof) + }; + ($zk:expr, $num_proof:expr, $accumulator_indices:expr) => { + $crate::system::halo2::Config::kzg() + .set_zk($zk) + .with_num_proof($num_proof) + .with_accumulator_indices($accumulator_indices) + }; +} + +macro_rules! halo2_kzg_prepare { + ($k:expr, $config:expr, $create_circuit:expr) => {{ + use halo2_curves::bn256::Bn256; + use $crate::system::halo2::test::{ + halo2_prepare, + kzg::{setup, TESTDATA_DIR}, + }; + + halo2_prepare!(TESTDATA_DIR, $k, setup::, $config, $create_circuit) + }}; +} + +macro_rules! halo2_kzg_create_snark { + ( + $prover:ty, + $verifier:ty, + $transcript_read:ty, + $transcript_write:ty, + $encoded_challenge:ty, + $params:expr, + $pk:expr, + $protocol:expr, + $circuits:expr + ) => {{ + use halo2_proofs::poly::kzg::{commitment::KZGCommitmentScheme, strategy::SingleStrategy}; + use $crate::system::halo2::test::halo2_create_snark; + + halo2_create_snark!( + KZGCommitmentScheme<_>, + $prover, + $verifier, + SingleStrategy<_>, + $transcript_read, + $transcript_write, + $encoded_challenge, + |proof, _| proof, + $params, + $pk, + $protocol, + $circuits + ) + }}; +} + +macro_rules! halo2_kzg_native_verify { + ( + $plonk_verifier:ty, + $params:expr, + $protocol:expr, + $instances:expr, + $transcript:expr + ) => {{ + use $crate::system::halo2::test::halo2_native_verify; + + halo2_native_verify!( + $plonk_verifier, + $params, + $protocol, + $instances, + $transcript, + &$params.get_g()[0].into(), + &($params.g2(), $params.s_g2()).into() + ) + }}; +} + +pub(crate) use { + halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, halo2_kzg_prepare, +}; diff --git a/src/system/halo2/test/kzg/evm.rs b/src/system/halo2/test/kzg/evm.rs new file mode 100644 index 00000000..4ce850c1 --- /dev/null +++ b/src/system/halo2/test/kzg/evm.rs @@ -0,0 +1,138 @@ +use crate::{ + loader::native::NativeLoader, + pcs::kzg::{Bdfg21, Gwc19, Kzg, LimbsEncoding}, + system::halo2::{ + test::{ + kzg::{ + self, halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, + halo2_kzg_prepare, main_gate_with_range_with_mock_kzg_accumulator, BITS, LIMBS, + }, + StandardPlonk, + }, + transcript::evm::{ChallengeEvm, EvmTranscript}, + }, + verifier::Plonk, +}; +use halo2_curves::bn256::{Bn256, G1Affine}; +use halo2_proofs::poly::kzg::multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}; +use paste::paste; +use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; + +macro_rules! halo2_kzg_evm_verify { + ($plonk_verifier:ty, $params:expr, $protocol:expr, $instances:expr, $proof:expr) => {{ + use halo2_curves::bn256::{Bn256, Fq, Fr}; + use halo2_proofs::poly::commitment::ParamsProver; + use std::rc::Rc; + use $crate::{ + loader::evm::{encode_calldata, execute, EvmLoader}, + system::halo2::{ + test::kzg::{BITS, LIMBS}, + transcript::evm::EvmTranscript, + }, + util::Itertools, + verifier::PlonkVerifier, + }; + + let loader = EvmLoader::new::(); + let runtime_code = { + let svk = $params.get_g()[0].into(); + let dk = ($params.g2(), $params.s_g2()).into(); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); + let instances = transcript.load_instances( + $instances + .iter() + .map(|instances| instances.len()) + .collect_vec(), + ); + let proof = <$plonk_verifier>::read_proof(&svk, $protocol, &instances, &mut transcript) + .unwrap(); + <$plonk_verifier>::verify(&svk, &dk, $protocol, &instances, &proof).unwrap(); + + loader.runtime_code() + }; + + let (accept, total_cost, costs) = + execute(runtime_code, encode_calldata($instances, &$proof)); + + loader.print_gas_metering(costs); + println!("Total gas cost: {}", total_cost); + + assert!(accept); + }}; +} + +macro_rules! test { + (@ $(#[$attr:meta],)* $prefix:ident, $name:ident, $k:expr, $config:expr, $create_circuit:expr, $prover:ty, $verifier:ty, $plonk_verifier:ty) => { + paste! { + $(#[$attr])* + fn []() { + let (params, pk, protocol, circuits) = halo2_kzg_prepare!( + $k, + $config, + $create_circuit + ); + let snark = halo2_kzg_create_snark!( + $prover, + $verifier, + EvmTranscript, + EvmTranscript, + ChallengeEvm<_>, + ¶ms, + &pk, + &protocol, + &circuits + ); + halo2_kzg_native_verify!( + $plonk_verifier, + params, + &snark.protocol, + &snark.instances, + &mut EvmTranscript::<_, NativeLoader, _, _>::new(snark.proof.as_slice()) + ); + halo2_kzg_evm_verify!( + $plonk_verifier, + params, + &snark.protocol, + &snark.instances, + snark.proof + ); + } + } + }; + ($name:ident, $k:expr, $config:expr, $create_circuit:expr) => { + test!(@ #[test], shplonk, $name, $k, $config, $create_circuit, ProverSHPLONK<_>, VerifierSHPLONK<_>, Plonk, LimbsEncoding>); + test!(@ #[test], plonk, $name, $k, $config, $create_circuit, ProverGWC<_>, VerifierGWC<_>, Plonk, LimbsEncoding>); + }; + ($(#[$attr:meta],)* $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { + test!(@ #[test] $(,#[$attr])*, plonk, $name, $k, $config, $create_circuit, ProverGWC<_>, VerifierGWC<_>, Plonk, LimbsEncoding>); + }; +} + +test!( + zk_standard_plonk_rand, + 9, + halo2_kzg_config!(true, 1), + StandardPlonk::rand(ChaCha20Rng::from_seed(Default::default())) +); +test!( + zk_main_gate_with_range_with_mock_kzg_accumulator, + 9, + halo2_kzg_config!(true, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), + main_gate_with_range_with_mock_kzg_accumulator::() +); +test!( + #[cfg(feature = "loader_halo2")], + #[ignore = "cause it requires 32GB memory to run"], + zk_accumulation_two_snark, + 22, + halo2_kzg_config!(true, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), + kzg::halo2::Accumulation::two_snark() +); +test!( + #[cfg(feature = "loader_halo2")], + #[ignore = "cause it requires 32GB memory to run"], + zk_accumulation_two_snark_with_accumulator, + 22, + halo2_kzg_config!(true, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), + kzg::halo2::Accumulation::two_snark_with_accumulator() +); diff --git a/src/system/halo2/test/kzg/halo2.rs b/src/system/halo2/test/kzg/halo2.rs new file mode 100644 index 00000000..1bd332a0 --- /dev/null +++ b/src/system/halo2/test/kzg/halo2.rs @@ -0,0 +1,372 @@ +use crate::{ + loader, + loader::{ + halo2::test::{Snark, SnarkWitness}, + native::NativeLoader, + }, + pcs::{ + kzg::{ + Bdfg21, Kzg, KzgAccumulator, KzgAs, KzgAsProvingKey, KzgAsVerifyingKey, + KzgSuccinctVerifyingKey, LimbsEncoding, + }, + AccumulationScheme, AccumulationSchemeProver, + }, + system::{ + self, + halo2::{ + test::{ + kzg::{ + halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, + halo2_kzg_prepare, BITS, LIMBS, + }, + MainGateWithRange, MainGateWithRangeConfig, StandardPlonk, + }, + transcript::halo2::ChallengeScalar, + }, + }, + util::{arithmetic::fe_to_limbs, Itertools}, + verifier::{self, PlonkVerifier}, +}; +use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; +use halo2_proofs::{ + circuit::{floor_planner::V1, Layouter, Value}, + plonk, + plonk::Circuit, + poly::{ + commitment::ParamsProver, + kzg::{ + commitment::ParamsKZG, + multiopen::{ProverSHPLONK, VerifierSHPLONK}, + }, + }, + transcript::{Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer}, +}; +use halo2_wrong_ecc::{ + self, + integer::rns::Rns, + maingate::{MainGateInstructions, RangeInstructions, RegionCtx}, +}; +use paste::paste; +use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; +use std::{iter, rc::Rc}; + +const T: usize = 5; +const RATE: usize = 4; +const R_F: usize = 8; +const R_P: usize = 60; + +type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip; +type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; +type PoseidonTranscript = + system::halo2::transcript::halo2::PoseidonTranscript; + +type Pcs = Kzg; +type Svk = KzgSuccinctVerifyingKey; +type As = KzgAs; +type AsPk = KzgAsProvingKey; +type AsVk = KzgAsVerifyingKey; +type Plonk = verifier::Plonk>; + +pub fn accumulate<'a>( + svk: &Svk, + loader: &Rc>, + snarks: &[SnarkWitness], + as_vk: &AsVk, + as_proof: Value<&'_ [u8]>, +) -> KzgAccumulator>> { + let assign_instances = |instances: &[Vec>]| { + instances + .iter() + .map(|instances| { + instances + .iter() + .map(|instance| loader.assign_scalar(*instance)) + .collect_vec() + }) + .collect_vec() + }; + + let mut accumulators = snarks + .iter() + .flat_map(|snark| { + let instances = assign_instances(&snark.instances); + let mut transcript = + PoseidonTranscript::, _>::new(loader, snark.proof()); + let proof = + Plonk::read_proof(svk, &snark.protocol, &instances, &mut transcript).unwrap(); + Plonk::succinct_verify(svk, &snark.protocol, &instances, &proof).unwrap() + }) + .collect_vec(); + + let acccumulator = if accumulators.len() > 1 { + let mut transcript = PoseidonTranscript::, _>::new(loader, as_proof); + let proof = As::read_proof(as_vk, &accumulators, &mut transcript).unwrap(); + As::verify(as_vk, &accumulators, &proof).unwrap() + } else { + accumulators.pop().unwrap() + }; + + acccumulator +} + +pub struct Accumulation { + svk: Svk, + snarks: Vec>, + instances: Vec, + as_vk: AsVk, + as_proof: Value>, +} + +impl Accumulation { + pub fn accumulator_indices() -> Vec<(usize, usize)> { + (0..4 * LIMBS).map(|idx| (0, idx)).collect() + } + + pub fn new( + params: &ParamsKZG, + snarks: impl IntoIterator>, + ) -> Self { + let svk = params.get_g()[0].into(); + let snarks = snarks.into_iter().collect_vec(); + + let mut accumulators = snarks + .iter() + .flat_map(|snark| { + let mut transcript = + PoseidonTranscript::::new(snark.proof.as_slice()); + let proof = + Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript) + .unwrap(); + Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof).unwrap() + }) + .collect_vec(); + + let as_pk = AsPk::new(Some((params.get_g()[0], params.get_g()[1]))); + let (accumulator, as_proof) = if accumulators.len() > 1 { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + let accumulator = As::create_proof( + &as_pk, + &accumulators, + &mut transcript, + ChaCha20Rng::from_seed(Default::default()), + ) + .unwrap(); + (accumulator, Value::known(transcript.finalize())) + } else { + (accumulators.pop().unwrap(), Value::unknown()) + }; + + let KzgAccumulator { lhs, rhs } = accumulator; + let instances = [lhs.x, lhs.y, rhs.x, rhs.y] + .map(fe_to_limbs::<_, _, LIMBS, BITS>) + .concat(); + + Self { + svk, + snarks: snarks.into_iter().map_into().collect(), + instances, + as_vk: as_pk.vk(), + as_proof, + } + } + + pub fn two_snark() -> Self { + let (params, snark1) = { + const K: u32 = 9; + let (params, pk, protocol, circuits) = halo2_kzg_prepare!( + K, + halo2_kzg_config!(true, 1), + StandardPlonk::<_>::rand(ChaCha20Rng::from_seed(Default::default())) + ); + let snark = halo2_kzg_create_snark!( + ProverSHPLONK<_>, + VerifierSHPLONK<_>, + PoseidonTranscript<_, _>, + PoseidonTranscript<_, _>, + ChallengeScalar<_>, + ¶ms, + &pk, + &protocol, + &circuits + ); + (params, snark) + }; + let snark2 = { + const K: u32 = 9; + let (params, pk, protocol, circuits) = halo2_kzg_prepare!( + K, + halo2_kzg_config!(true, 1), + MainGateWithRange::rand(ChaCha20Rng::from_seed(Default::default())) + ); + halo2_kzg_create_snark!( + ProverSHPLONK<_>, + VerifierSHPLONK<_>, + PoseidonTranscript<_, _>, + PoseidonTranscript<_, _>, + ChallengeScalar<_>, + ¶ms, + &pk, + &protocol, + &circuits + ) + }; + Self::new(¶ms, [snark1, snark2]) + } + + pub fn two_snark_with_accumulator() -> Self { + let (params, pk, protocol, circuits) = { + const K: u32 = 22; + halo2_kzg_prepare!( + K, + halo2_kzg_config!(true, 2, Self::accumulator_indices()), + Self::two_snark() + ) + }; + let snark = halo2_kzg_create_snark!( + ProverSHPLONK<_>, + VerifierSHPLONK<_>, + PoseidonTranscript<_, _>, + PoseidonTranscript<_, _>, + ChallengeScalar<_>, + ¶ms, + &pk, + &protocol, + &circuits + ); + Self::new(¶ms, [snark]) + } + + pub fn instances(&self) -> Vec> { + vec![self.instances.clone()] + } + + pub fn as_proof(&self) -> Value<&[u8]> { + self.as_proof.as_ref().map(Vec::as_slice) + } +} + +impl Circuit for Accumulation { + type Config = MainGateWithRangeConfig; + type FloorPlanner = V1; + + fn without_witnesses(&self) -> Self { + Self { + svk: self.svk, + snarks: self + .snarks + .iter() + .map(SnarkWitness::without_witnesses) + .collect(), + instances: Vec::new(), + as_vk: self.as_vk, + as_proof: Value::unknown(), + } + } + + fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { + MainGateWithRangeConfig::configure( + meta, + vec![BITS / LIMBS], + Rns::::construct().overflow_lengths(), + ) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), plonk::Error> { + let main_gate = config.main_gate(); + let range_chip = config.range_chip(); + + range_chip.load_table(&mut layouter)?; + + let (lhs, rhs) = layouter.assign_region( + || "", + |region| { + let ctx = RegionCtx::new(region, 0); + + let ecc_chip = config.ecc_chip(); + let loader = Halo2Loader::new(ecc_chip, ctx); + let KzgAccumulator { lhs, rhs } = accumulate( + &self.svk, + &loader, + &self.snarks, + &self.as_vk, + self.as_proof(), + ); + + loader.print_row_metering(); + println!("Total row cost: {}", loader.ctx().offset()); + + Ok((lhs.assigned(), rhs.assigned())) + }, + )?; + + for (limb, row) in iter::empty() + .chain(lhs.x().limbs()) + .chain(lhs.y().limbs()) + .chain(rhs.x().limbs()) + .chain(rhs.y().limbs()) + .zip(0..) + { + main_gate.expose_public(layouter.namespace(|| ""), limb.into(), row)?; + } + + Ok(()) + } +} + +macro_rules! test { + (@ $(#[$attr:meta],)* $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { + paste! { + $(#[$attr])* + fn []() { + let (params, pk, protocol, circuits) = halo2_kzg_prepare!( + $k, + $config, + $create_circuit + ); + let snark = halo2_kzg_create_snark!( + ProverSHPLONK<_>, + VerifierSHPLONK<_>, + Blake2bWrite<_, _, _>, + Blake2bRead<_, _, _>, + Challenge255<_>, + ¶ms, + &pk, + &protocol, + &circuits + ); + halo2_kzg_native_verify!( + Plonk, + params, + &snark.protocol, + &snark.instances, + &mut Blake2bRead::<_, G1Affine, _>::init(snark.proof.as_slice()) + ); + } + } + }; + ($name:ident, $k:expr, $config:expr, $create_circuit:expr) => { + test!(@ #[test], $name, $k, $config, $create_circuit); + }; + ($(#[$attr:meta],)* $name:ident, $k:expr, $config:expr, $create_circuit:expr) => { + test!(@ #[test] $(,#[$attr])*, $name, $k, $config, $create_circuit); + }; +} + +test!( + #[ignore = "cause it requires 32GB memory to run"], + zk_accumulation_two_snark, + 22, + halo2_kzg_config!(true, 1, Accumulation::accumulator_indices()), + Accumulation::two_snark() +); +test!( + #[ignore = "cause it requires 32GB memory to run"], + zk_accumulation_two_snark_with_accumulator, + 22, + halo2_kzg_config!(true, 1, Accumulation::accumulator_indices()), + Accumulation::two_snark_with_accumulator() +); diff --git a/src/protocol/halo2/test/kzg/native.rs b/src/system/halo2/test/kzg/native.rs similarity index 50% rename from src/protocol/halo2/test/kzg/native.rs rename to src/system/halo2/test/kzg/native.rs index 4273e7d0..e52ceb38 100644 --- a/src/protocol/halo2/test/kzg/native.rs +++ b/src/system/halo2/test/kzg/native.rs @@ -1,61 +1,56 @@ use crate::{ - collect_slice, halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, - halo2_kzg_prepare, - protocol::halo2::test::{ + pcs::kzg::{Bdfg21, Gwc19, Kzg, LimbsEncoding}, + system::halo2::test::{ kzg::{ - main_gate_with_plookup_with_mock_kzg_accumulator, - main_gate_with_range_with_mock_kzg_accumulator, LIMBS, + halo2_kzg_config, halo2_kzg_create_snark, halo2_kzg_native_verify, halo2_kzg_prepare, + main_gate_with_range_with_mock_kzg_accumulator, BITS, LIMBS, }, StandardPlonk, }, - scheme::kzg::{PlonkAccumulationScheme, ShplonkAccumulationScheme}, + verifier::Plonk, }; -use halo2_curves::bn256::G1Affine; +use halo2_curves::bn256::{Bn256, G1Affine}; use halo2_proofs::{ - poly::kzg::{ - multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}, - strategy::AccumulatorStrategy, - }, + poly::kzg::multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}, transcript::{Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer}, }; use paste::paste; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; macro_rules! test { - (@ $prefix:ident, $name:ident, $k:expr, $config:expr, $create_cirucit:expr, $prover:ty, $verifier:ty, $scheme:ty) => { + (@ $prefix:ident, $name:ident, $k:expr, $config:expr, $create_cirucit:expr, $prover:ty, $verifier:ty, $plonk_verifier:ty) => { paste! { #[test] - fn []() { + fn []() { let (params, pk, protocol, circuits) = halo2_kzg_prepare!( $k, $config, $create_cirucit ); let snark = halo2_kzg_create_snark!( - ¶ms, - &pk, - &protocol, - &circuits, $prover, $verifier, - AccumulatorStrategy<_>, Blake2bWrite<_, _, _>, Blake2bRead<_, _, _>, - Challenge255<_> + Challenge255<_>, + ¶ms, + &pk, + &protocol, + &circuits ); halo2_kzg_native_verify!( + $plonk_verifier, params, &snark.protocol, - snark.statements, - $scheme, + &snark.instances, &mut Blake2bRead::<_, G1Affine, _>::init(snark.proof.as_slice()) ); } } }; ($name:ident, $k:expr, $config:expr, $create_cirucit:expr) => { - test!(@ shplonk, $name, $k, $config, $create_cirucit, ProverSHPLONK<_>, VerifierSHPLONK<_>, ShplonkAccumulationScheme); - test!(@ plonk, $name, $k, $config, $create_cirucit, ProverGWC<_>, VerifierGWC<_>, PlonkAccumulationScheme); + test!(@ shplonk, $name, $k, $config, $create_cirucit, ProverSHPLONK<_>, VerifierSHPLONK<_>, Plonk, LimbsEncoding>); + test!(@ plonk, $name, $k, $config, $create_cirucit, ProverGWC<_>, VerifierGWC<_>, Plonk, LimbsEncoding>); } } @@ -63,7 +58,7 @@ test!( zk_standard_plonk_rand, 9, halo2_kzg_config!(true, 2), - StandardPlonk::<_>::rand(ChaCha20Rng::from_seed(Default::default())) + StandardPlonk::rand(ChaCha20Rng::from_seed(Default::default())) ); test!( zk_main_gate_with_range_with_mock_kzg_accumulator, @@ -71,21 +66,3 @@ test!( halo2_kzg_config!(true, 2, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), main_gate_with_range_with_mock_kzg_accumulator::() ); -test!( - standard_plonk_rand, - 9, - halo2_kzg_config!(false, 2), - StandardPlonk::<_>::rand(ChaCha20Rng::from_seed(Default::default())) -); -test!( - main_gate_with_range_with_mock_kzg_accumulator, - 9, - halo2_kzg_config!(false, 2, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - main_gate_with_range_with_mock_kzg_accumulator::() -); -test!( - main_gate_with_plookup_with_mock_kzg_accumulator, - 9, - halo2_kzg_config!(false, 1, (0..4 * LIMBS).map(|idx| (0, idx)).collect()), - main_gate_with_plookup_with_mock_kzg_accumulator::(9) -); diff --git a/src/system/halo2/transcript.rs b/src/system/halo2/transcript.rs new file mode 100644 index 00000000..2200bbf4 --- /dev/null +++ b/src/system/halo2/transcript.rs @@ -0,0 +1,82 @@ +use crate::{ + loader::native::{self, NativeLoader}, + util::{ + arithmetic::CurveAffine, + transcript::{Transcript, TranscriptRead, TranscriptWrite}, + }, + Error, +}; +use halo2_proofs::transcript::{Blake2bRead, Blake2bWrite, Challenge255}; +use std::io::{Read, Write}; + +#[cfg(feature = "loader_evm")] +pub mod evm; + +#[cfg(feature = "loader_halo2")] +pub mod halo2; + +impl Transcript for Blake2bRead> { + fn loader(&self) -> &NativeLoader { + &native::LOADER + } + + fn squeeze_challenge(&mut self) -> C::Scalar { + *halo2_proofs::transcript::Transcript::squeeze_challenge_scalar::(self) + } + + fn common_ec_point(&mut self, ec_point: &C) -> Result<(), Error> { + halo2_proofs::transcript::Transcript::common_point(self, *ec_point) + .map_err(|err| Error::Transcript(err.kind(), err.to_string())) + } + + fn common_scalar(&mut self, scalar: &C::Scalar) -> Result<(), Error> { + halo2_proofs::transcript::Transcript::common_scalar(self, *scalar) + .map_err(|err| Error::Transcript(err.kind(), err.to_string())) + } +} + +impl TranscriptRead + for Blake2bRead> +{ + fn read_scalar(&mut self) -> Result { + halo2_proofs::transcript::TranscriptRead::read_scalar(self) + .map_err(|err| Error::Transcript(err.kind(), err.to_string())) + } + + fn read_ec_point(&mut self) -> Result { + halo2_proofs::transcript::TranscriptRead::read_point(self) + .map_err(|err| Error::Transcript(err.kind(), err.to_string())) + } +} + +impl Transcript for Blake2bWrite> { + fn loader(&self) -> &NativeLoader { + &native::LOADER + } + + fn squeeze_challenge(&mut self) -> C::Scalar { + *halo2_proofs::transcript::Transcript::squeeze_challenge_scalar::(self) + } + + fn common_ec_point(&mut self, ec_point: &C) -> Result<(), Error> { + halo2_proofs::transcript::Transcript::common_point(self, *ec_point) + .map_err(|err| Error::Transcript(err.kind(), err.to_string())) + } + + fn common_scalar(&mut self, scalar: &C::Scalar) -> Result<(), Error> { + halo2_proofs::transcript::Transcript::common_scalar(self, *scalar) + .map_err(|err| Error::Transcript(err.kind(), err.to_string())) + } +} + +impl TranscriptWrite for Blake2bWrite, C, Challenge255> { + fn write_scalar(&mut self, scalar: C::Scalar) -> Result<(), Error> { + halo2_proofs::transcript::TranscriptWrite::write_scalar(self, scalar) + .map_err(|err| Error::Transcript(err.kind(), err.to_string())) + } + + fn write_ec_point(&mut self, ec_point: C) -> Result<(), Error> { + halo2_proofs::transcript::TranscriptWrite::write_point(self, ec_point) + .map_err(|err| Error::Transcript(err.kind(), err.to_string())) + } +} diff --git a/src/system/halo2/transcript/evm.rs b/src/system/halo2/transcript/evm.rs new file mode 100644 index 00000000..77aca5cf --- /dev/null +++ b/src/system/halo2/transcript/evm.rs @@ -0,0 +1,400 @@ +use crate::{ + loader::{ + evm::{loader::Value, u256_to_fe, EcPoint, EvmLoader, MemoryChunk, Scalar}, + native::{self, NativeLoader}, + Loader, + }, + util::{ + arithmetic::{Coordinates, CurveAffine, PrimeField}, + hash::{Digest, Keccak256}, + transcript::{Transcript, TranscriptRead}, + Itertools, + }, + Error, +}; +use ethereum_types::U256; +use halo2_proofs::transcript::EncodedChallenge; +use std::{ + io::{self, Read, Write}, + iter, + marker::PhantomData, + rc::Rc, +}; +pub struct EvmTranscript, S, B> { + loader: L, + stream: S, + buf: B, + _marker: PhantomData, +} + +impl EvmTranscript, usize, MemoryChunk> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + pub fn new(loader: Rc) -> Self { + let ptr = loader.allocate(0x20); + assert_eq!(ptr, 0); + let mut buf = MemoryChunk::new(ptr); + buf.extend(0x20); + Self { + loader, + stream: 0, + buf, + _marker: PhantomData, + } + } + + pub fn load_instances(&mut self, num_instance: Vec) -> Vec> { + num_instance + .into_iter() + .map(|len| { + iter::repeat_with(|| { + let scalar = self.loader.calldataload_scalar(self.stream); + self.stream += 0x20; + scalar + }) + .take(len) + .collect_vec() + }) + .collect() + } +} + +impl Transcript> for EvmTranscript, usize, MemoryChunk> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + fn loader(&self) -> &Rc { + &self.loader + } + + fn squeeze_challenge(&mut self) -> Scalar { + let len = if self.buf.len() == 0x20 { + assert_eq!(self.loader.ptr(), self.buf.end()); + self.loader + .code_mut() + .push(1) + .push(self.buf.end()) + .mstore8(); + 0x21 + } else { + self.buf.len() + }; + let hash_ptr = self.loader.keccak256(self.buf.ptr(), len); + + let challenge_ptr = self.loader.allocate(0x20); + let dup_hash_ptr = self.loader.allocate(0x20); + self.loader + .code_mut() + .push(hash_ptr) + .mload() + .push(self.loader.scalar_modulus()) + .dup(1) + .r#mod() + .push(challenge_ptr) + .mstore() + .push(dup_hash_ptr) + .mstore(); + + self.buf.reset(dup_hash_ptr); + self.buf.extend(0x20); + + self.loader.scalar(Value::Memory(challenge_ptr)) + } + + fn common_ec_point(&mut self, ec_point: &EcPoint) -> Result<(), Error> { + if let Value::Memory(ptr) = ec_point.value() { + assert_eq!(self.buf.end(), ptr); + self.buf.extend(0x40); + } else { + unreachable!() + } + Ok(()) + } + + fn common_scalar(&mut self, scalar: &Scalar) -> Result<(), Error> { + match scalar.value() { + Value::Constant(_) if self.buf.ptr() == 0 => { + self.loader.copy_scalar(scalar, self.buf.ptr()); + } + Value::Memory(ptr) => { + assert_eq!(self.buf.end(), ptr); + self.buf.extend(0x20); + } + _ => unreachable!(), + } + Ok(()) + } +} + +impl TranscriptRead> for EvmTranscript, usize, MemoryChunk> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + fn read_scalar(&mut self) -> Result { + let scalar = self.loader.calldataload_scalar(self.stream); + self.stream += 0x20; + self.common_scalar(&scalar)?; + Ok(scalar) + } + + fn read_ec_point(&mut self) -> Result { + let ec_point = self.loader.calldataload_ec_point(self.stream); + self.stream += 0x40; + self.common_ec_point(&ec_point)?; + Ok(ec_point) + } +} + +impl EvmTranscript> +where + C: CurveAffine, +{ + pub fn new(stream: S) -> Self { + Self { + loader: NativeLoader, + stream, + buf: Vec::new(), + _marker: PhantomData, + } + } +} + +impl Transcript for EvmTranscript> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + fn loader(&self) -> &NativeLoader { + &native::LOADER + } + + fn squeeze_challenge(&mut self) -> C::Scalar { + let data = self + .buf + .iter() + .cloned() + .chain(if self.buf.len() == 0x20 { + Some(1) + } else { + None + }) + .collect_vec(); + let hash: [u8; 32] = Keccak256::digest(data).into(); + self.buf = hash.to_vec(); + u256_to_fe(U256::from_big_endian(hash.as_slice())) + } + + fn common_ec_point(&mut self, ec_point: &C) -> Result<(), Error> { + let coordinates = + Option::>::from(ec_point.coordinates()).ok_or_else(|| { + Error::Transcript( + io::ErrorKind::Other, + "Cannot write points at infinity to the transcript".to_string(), + ) + })?; + + [coordinates.x(), coordinates.y()].map(|coordinate| { + self.buf + .extend(coordinate.to_repr().as_ref().iter().rev().cloned()); + }); + + Ok(()) + } + + fn common_scalar(&mut self, scalar: &C::Scalar) -> Result<(), Error> { + self.buf.extend(scalar.to_repr().as_ref().iter().rev()); + + Ok(()) + } +} + +impl TranscriptRead for EvmTranscript> +where + C: CurveAffine, + C::Scalar: PrimeField, + S: Read, +{ + fn read_scalar(&mut self) -> Result { + let mut data = [0; 32]; + self.stream + .read_exact(data.as_mut()) + .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; + data.reverse(); + let scalar = C::Scalar::from_repr_vartime(data).ok_or_else(|| { + Error::Transcript( + io::ErrorKind::Other, + "Invalid scalar encoding in proof".to_string(), + ) + })?; + self.common_scalar(&scalar)?; + Ok(scalar) + } + + fn read_ec_point(&mut self) -> Result { + let [mut x, mut y] = [::Repr::default(); 2]; + for repr in [&mut x, &mut y] { + self.stream + .read_exact(repr.as_mut()) + .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; + repr.as_mut().reverse(); + } + let x = Option::from(::from_repr(x)); + let y = Option::from(::from_repr(y)); + let ec_point = x + .zip(y) + .and_then(|(x, y)| Option::from(C::from_xy(x, y))) + .ok_or_else(|| { + Error::Transcript( + io::ErrorKind::Other, + "Invalid elliptic curve point encoding in proof".to_string(), + ) + })?; + self.common_ec_point(&ec_point)?; + Ok(ec_point) + } +} + +impl EvmTranscript> +where + C: CurveAffine, + S: Write, +{ + pub fn stream_mut(&mut self) -> &mut S { + &mut self.stream + } + + pub fn finalize(self) -> S { + self.stream + } +} + +pub struct ChallengeEvm(C::Scalar) +where + C: CurveAffine, + C::Scalar: PrimeField; + +impl EncodedChallenge for ChallengeEvm +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + type Input = [u8; 32]; + + fn new(challenge_input: &[u8; 32]) -> Self { + ChallengeEvm(u256_to_fe(U256::from_big_endian(challenge_input))) + } + + fn get_scalar(&self) -> C::Scalar { + self.0 + } +} + +impl halo2_proofs::transcript::Transcript> + for EvmTranscript> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + fn squeeze_challenge(&mut self) -> ChallengeEvm { + ChallengeEvm(Transcript::squeeze_challenge(self)) + } + + fn common_point(&mut self, ec_point: C) -> io::Result<()> { + match Transcript::common_ec_point(self, &ec_point) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + _ => Ok(()), + } + } + + fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + match Transcript::common_scalar(self, &scalar) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + _ => Ok(()), + } + } +} + +impl halo2_proofs::transcript::TranscriptRead> + for EvmTranscript> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + fn read_point(&mut self) -> io::Result { + match TranscriptRead::read_ec_point(self) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + Ok(value) => Ok(value), + } + } + + fn read_scalar(&mut self) -> io::Result { + match TranscriptRead::read_scalar(self) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + Ok(value) => Ok(value), + } + } +} + +impl halo2_proofs::transcript::TranscriptReadBuffer> + for EvmTranscript> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + fn init(reader: R) -> Self { + Self::new(reader) + } +} + +impl halo2_proofs::transcript::TranscriptWrite> + for EvmTranscript> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + fn write_point(&mut self, ec_point: C) -> io::Result<()> { + halo2_proofs::transcript::Transcript::>::common_point(self, ec_point)?; + let coords: Coordinates = Option::from(ec_point.coordinates()).ok_or_else(|| { + io::Error::new( + io::ErrorKind::Other, + "Cannot write points at infinity to the transcript", + ) + })?; + let mut x = coords.x().to_repr(); + let mut y = coords.y().to_repr(); + x.as_mut().reverse(); + y.as_mut().reverse(); + self.stream_mut().write_all(x.as_ref())?; + self.stream_mut().write_all(y.as_ref()) + } + + fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + halo2_proofs::transcript::Transcript::>::common_scalar(self, scalar)?; + let mut data = scalar.to_repr(); + data.as_mut().reverse(); + self.stream_mut().write_all(data.as_ref()) + } +} + +impl halo2_proofs::transcript::TranscriptWriterBuffer> + for EvmTranscript> +where + C: CurveAffine, + C::Scalar: PrimeField, +{ + fn init(writer: W) -> Self { + Self::new(writer) + } + + fn finalize(self) -> W { + self.finalize() + } +} diff --git a/src/system/halo2/transcript/halo2.rs b/src/system/halo2/transcript/halo2.rs new file mode 100644 index 00000000..43701bd7 --- /dev/null +++ b/src/system/halo2/transcript/halo2.rs @@ -0,0 +1,439 @@ +use crate::{ + loader::{ + halo2::{self, EcPoint, EccInstructions, Halo2Loader, IntegerInstructions, Scalar}, + native::{self, NativeLoader}, + Loader, ScalarLoader, + }, + util::{ + arithmetic::{fe_from_big, fe_to_big, CurveAffine, FieldExt, PrimeField}, + hash::Poseidon, + transcript::{Transcript, TranscriptRead, TranscriptWrite}, + Itertools, + }, + Error, +}; +use halo2_proofs::{ + circuit::{AssignedCell, Value}, + transcript::EncodedChallenge, +}; +use std::{ + io::{self, Read, Write}, + marker::PhantomData, + rc::Rc, +}; + +pub trait EncodeNative<'a, C: CurveAffine, N: FieldExt>: EccInstructions<'a, C> { + fn encode_native( + &self, + ctx: &mut Self::Context, + ec_point: &Self::AssignedEcPoint, + ) -> Result>, Error>; +} + +pub struct PoseidonTranscript< + C: CurveAffine, + L: Loader, + S, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +> { + loader: L, + stream: S, + buf: Poseidon>::LoadedScalar, T, RATE>, + _marker: PhantomData, +} + +impl< + 'a, + C: CurveAffine, + R: Read, + EccChip: EncodeNative<'a, C, C::Scalar, AssignedScalar = AssignedCell>, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > PoseidonTranscript>, Value, T, RATE, R_F, R_P> +{ + pub fn new(loader: &Rc>, stream: Value) -> Self { + Self { + loader: loader.clone(), + stream, + buf: Poseidon::new(loader.clone(), R_F, R_P), + _marker: PhantomData, + } + } +} + +impl< + 'a, + C: CurveAffine, + R: Read, + EccChip: EncodeNative<'a, C, C::Scalar, AssignedScalar = AssignedCell>, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > Transcript>> + for PoseidonTranscript>, Value, T, RATE, R_F, R_P> +{ + fn loader(&self) -> &Rc> { + &self.loader + } + + fn squeeze_challenge(&mut self) -> Scalar<'a, C, EccChip> { + self.buf.squeeze() + } + + fn common_scalar(&mut self, scalar: &Scalar<'a, C, EccChip>) -> Result<(), Error> { + self.buf.update(&[scalar.clone()]); + Ok(()) + } + + fn common_ec_point(&mut self, ec_point: &EcPoint<'a, C, EccChip>) -> Result<(), Error> { + let encoded = self + .loader + .ecc_chip() + .encode_native(&mut self.loader.ctx_mut(), &ec_point.assigned()) + .map(|encoded| { + encoded + .into_iter() + .map(|encoded| self.loader.scalar(halo2::loader::Value::Assigned(encoded))) + .collect_vec() + }) + .map_err(|_| Error::Transcript(io::ErrorKind::Other, "".to_string()))?; + self.buf.update(&encoded); + Ok(()) + } +} + +impl< + 'a, + C: CurveAffine, + R: Read, + EccChip: EncodeNative<'a, C, C::Scalar, AssignedScalar = AssignedCell>, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > TranscriptRead>> + for PoseidonTranscript>, Value, T, RATE, R_F, R_P> +{ + fn read_scalar(&mut self) -> Result, Error> { + let scalar = self.stream.as_mut().and_then(|stream| { + let mut data = ::Repr::default(); + if stream.read_exact(data.as_mut()).is_err() { + return Value::unknown(); + } + Option::::from(C::Scalar::from_repr(data)) + .map(|scalar| Value::known(self.loader.scalar_chip().integer(scalar))) + .unwrap_or_else(Value::unknown) + }); + let scalar = self.loader.assign_scalar(scalar); + self.common_scalar(&scalar)?; + Ok(scalar) + } + + fn read_ec_point(&mut self) -> Result, Error> { + let ec_point = self.stream.as_mut().and_then(|stream| { + let mut compressed = C::Repr::default(); + if stream.read_exact(compressed.as_mut()).is_err() { + return Value::unknown(); + } + Option::::from(C::from_bytes(&compressed)) + .map(Value::known) + .unwrap_or_else(Value::unknown) + }); + let ec_point = self.loader.assign_ec_point(ec_point); + self.common_ec_point(&ec_point)?; + Ok(ec_point) + } +} + +// + +impl + PoseidonTranscript +{ + pub fn new(stream: S) -> Self { + Self { + loader: NativeLoader, + stream, + buf: Poseidon::new(NativeLoader, R_F, R_P), + _marker: PhantomData, + } + } +} + +impl + Transcript for PoseidonTranscript +{ + fn loader(&self) -> &NativeLoader { + &native::LOADER + } + + fn squeeze_challenge(&mut self) -> C::Scalar { + self.buf.squeeze() + } + + fn common_scalar(&mut self, scalar: &C::Scalar) -> Result<(), Error> { + self.buf.update(&[*scalar]); + Ok(()) + } + + fn common_ec_point(&mut self, ec_point: &C) -> Result<(), Error> { + let encoded: Vec<_> = Option::from(ec_point.coordinates().map(|coordinates| { + [coordinates.x(), coordinates.y()] + .into_iter() + .map(|fe| fe_from_big(fe_to_big(*fe))) + .collect_vec() + })) + .ok_or_else(|| { + Error::Transcript( + io::ErrorKind::Other, + "Invalid elliptic curve point encoding in proof".to_string(), + ) + })?; + self.buf.update(&encoded); + Ok(()) + } +} + +impl< + C: CurveAffine, + R: Read, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > TranscriptRead + for PoseidonTranscript +{ + fn read_scalar(&mut self) -> Result { + let mut data = ::Repr::default(); + self.stream + .read_exact(data.as_mut()) + .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; + let scalar = C::Scalar::from_repr_vartime(data).ok_or_else(|| { + Error::Transcript( + io::ErrorKind::Other, + "Invalid scalar encoding in proof".to_string(), + ) + })?; + self.common_scalar(&scalar)?; + Ok(scalar) + } + + fn read_ec_point(&mut self) -> Result { + let mut data = C::Repr::default(); + self.stream + .read_exact(data.as_mut()) + .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; + let ec_point = Option::::from(C::from_bytes(&data)).ok_or_else(|| { + Error::Transcript( + io::ErrorKind::Other, + "Invalid elliptic curve point encoding in proof".to_string(), + ) + })?; + self.common_ec_point(&ec_point)?; + Ok(ec_point) + } +} + +impl< + C: CurveAffine, + W: Write, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > PoseidonTranscript +{ + pub fn stream_mut(&mut self) -> &mut W { + &mut self.stream + } + + pub fn finalize(self) -> W { + self.stream + } +} + +impl< + C: CurveAffine, + W: Write, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > TranscriptWrite for PoseidonTranscript +{ + fn write_scalar(&mut self, scalar: C::Scalar) -> Result<(), Error> { + self.common_scalar(&scalar)?; + let data = scalar.to_repr(); + self.stream_mut().write_all(data.as_ref()).map_err(|err| { + Error::Transcript( + err.kind(), + "Failed to write scalar to transcript".to_string(), + ) + }) + } + + fn write_ec_point(&mut self, ec_point: C) -> Result<(), Error> { + self.common_ec_point(&ec_point)?; + let data = ec_point.to_bytes(); + self.stream_mut().write_all(data.as_ref()).map_err(|err| { + Error::Transcript( + err.kind(), + "Failed to write elliptic curve to transcript".to_string(), + ) + }) + } +} + +pub struct ChallengeScalar(C::Scalar); + +impl EncodedChallenge for ChallengeScalar { + type Input = C::Scalar; + + fn new(challenge_input: &C::Scalar) -> Self { + ChallengeScalar(*challenge_input) + } + + fn get_scalar(&self) -> C::Scalar { + self.0 + } +} + +impl + halo2_proofs::transcript::Transcript> + for PoseidonTranscript +{ + fn squeeze_challenge(&mut self) -> ChallengeScalar { + ChallengeScalar::new(&Transcript::squeeze_challenge(self)) + } + + fn common_point(&mut self, ec_point: C) -> io::Result<()> { + match Transcript::common_ec_point(self, &ec_point) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + _ => Ok(()), + } + } + + fn common_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + match Transcript::common_scalar(self, &scalar) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + _ => Ok(()), + } + } +} + +impl< + C: CurveAffine, + R: Read, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > halo2_proofs::transcript::TranscriptRead> + for PoseidonTranscript +{ + fn read_point(&mut self) -> io::Result { + match TranscriptRead::read_ec_point(self) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + Ok(value) => Ok(value), + } + } + + fn read_scalar(&mut self) -> io::Result { + match TranscriptRead::read_scalar(self) { + Err(Error::Transcript(kind, msg)) => Err(io::Error::new(kind, msg)), + Err(_) => unreachable!(), + Ok(value) => Ok(value), + } + } +} + +impl< + C: CurveAffine, + R: Read, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > halo2_proofs::transcript::TranscriptReadBuffer> + for PoseidonTranscript +{ + fn init(reader: R) -> Self { + Self::new(reader) + } +} + +impl< + C: CurveAffine, + W: Write, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > halo2_proofs::transcript::TranscriptWrite> + for PoseidonTranscript +{ + fn write_point(&mut self, ec_point: C) -> io::Result<()> { + halo2_proofs::transcript::Transcript::>::common_point( + self, ec_point, + )?; + let data = ec_point.to_bytes(); + self.stream_mut().write_all(data.as_ref()) + } + + fn write_scalar(&mut self, scalar: C::Scalar) -> io::Result<()> { + halo2_proofs::transcript::Transcript::>::common_scalar(self, scalar)?; + let data = scalar.to_repr(); + self.stream_mut().write_all(data.as_ref()) + } +} + +impl< + C: CurveAffine, + W: Write, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + > halo2_proofs::transcript::TranscriptWriterBuffer> + for PoseidonTranscript +{ + fn init(writer: W) -> Self { + Self::new(writer) + } + + fn finalize(self) -> W { + self.finalize() + } +} + +mod halo2_wrong { + use crate::system::halo2::transcript::halo2::EncodeNative; + use halo2_curves::CurveAffine; + use halo2_proofs::circuit::AssignedCell; + use halo2_wrong_ecc::BaseFieldEccChip; + + impl<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize> EncodeNative<'a, C, C::Scalar> + for BaseFieldEccChip + { + fn encode_native( + &self, + _: &mut Self::Context, + ec_point: &Self::AssignedEcPoint, + ) -> Result>, crate::Error> { + Ok(vec![ + ec_point.x().native().clone(), + ec_point.y().native().clone(), + ]) + } + } +} diff --git a/src/util.rs b/src/util.rs index eb38e56a..3d5d0d79 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,34 +1,7 @@ -mod arithmetic; -mod expression; -mod transcript; +pub mod arithmetic; +pub mod hash; +pub mod msm; +pub mod protocol; +pub mod transcript; -pub use arithmetic::{ - batch_invert, batch_invert_and_mul, fe_from_limbs, fe_to_limbs, BatchInvert, Curve, Domain, - Field, FieldOps, Fraction, Group, GroupEncoding, GroupOps, PrimeCurveAffine, PrimeField, - Rotation, UncompressedEncoding, -}; -pub use expression::{CommonPolynomial, CommonPolynomialEvaluation, Expression, Query}; -pub use transcript::{Transcript, TranscriptRead}; - -pub use itertools::{EitherOrBoth, Itertools}; - -#[macro_export] -macro_rules! collect_slice { - ($vec:ident) => { - use $crate::util::Itertools; - - let $vec = $vec.iter().map(|vec| vec.as_slice()).collect_vec(); - }; - ($vec:ident, 2) => { - use $crate::util::Itertools; - - let $vec = $vec - .iter() - .map(|vec| { - collect_slice!(vec); - vec - }) - .collect_vec(); - let $vec = $vec.iter().map(|vec| vec.as_slice()).collect_vec(); - }; -} +pub(crate) use itertools::Itertools; diff --git a/src/util/arithmetic.rs b/src/util/arithmetic.rs index 0ba1929b..b9a5c7c6 100644 --- a/src/util/arithmetic.rs +++ b/src/util/arithmetic.rs @@ -8,43 +8,34 @@ use std::{ ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; -pub use ff::{BatchInvert, Field, PrimeField}; -pub use group::{prime::PrimeCurveAffine, Curve, Group, GroupEncoding}; +pub use halo2_curves::{ + group::{ + ff::{BatchInvert, Field, PrimeField}, + prime::PrimeCurveAffine, + Curve, Group, GroupEncoding, + }, + pairing::MillerLoopResult, + Coordinates, CurveAffine, CurveExt, FieldExt, +}; + +pub trait MultiMillerLoop: halo2_curves::pairing::MultiMillerLoop + Debug {} + +impl MultiMillerLoop for M {} -pub trait GroupOps: +pub trait FieldOps: Sized + + Neg + Add + Sub - + Neg + + Mul + for<'a> Add<&'a Self, Output = Self> + for<'a> Sub<&'a Self, Output = Self> + + for<'a> Mul<&'a Self, Output = Self> + AddAssign + SubAssign + + MulAssign + for<'a> AddAssign<&'a Self> + for<'a> SubAssign<&'a Self> -{ -} - -impl GroupOps for T where - T: Sized - + Add - + Sub - + Neg - + for<'a> Add<&'a Self, Output = Self> - + for<'a> Sub<&'a Self, Output = Self> - + AddAssign - + SubAssign - + for<'a> AddAssign<&'a Self> - + for<'a> SubAssign<&'a Self> -{ -} - -pub trait FieldOps: - Sized - + GroupOps - + Mul - + for<'a> Mul<&'a Self, Output = Self> - + MulAssign + for<'a> MulAssign<&'a Self> { fn invert(&self) -> Option; @@ -78,12 +69,13 @@ pub fn batch_invert(values: &mut [F]) { batch_invert_and_mul(values, &F::one()) } -pub trait UncompressedEncoding: Sized { - type Uncompressed: AsRef<[u8]> + AsMut<[u8]>; +pub fn root_of_unity(k: usize) -> F { + assert!(k <= F::S as usize); - fn to_uncompressed(&self) -> Self::Uncompressed; - - fn from_uncompressed(uncompressed: Self::Uncompressed) -> Option; + iter::successors(Some(F::root_of_unity()), |acc| Some(acc.square())) + .take(F::S as usize - k + 1) + .last() + .unwrap() } #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -119,15 +111,9 @@ pub struct Domain { } impl Domain { - pub fn new(k: usize) -> Self { - assert!(k <= F::S as usize); - + pub fn new(k: usize, gen: F) -> Self { let n = 1 << k; let n_inv = F::from(n as u64).invert().unwrap(); - let gen = iter::successors(Some(F::root_of_unity()), |acc| Some(acc.square())) - .take(F::S as usize - k + 1) - .last() - .unwrap(); let gen_inv = gen.invert().unwrap(); Self { @@ -149,30 +135,33 @@ impl Domain { } #[derive(Clone, Debug)] -pub struct Fraction { - numer: Option, - denom: F, +pub struct Fraction { + numer: Option, + denom: T, + eval: Option, inv: bool, } -impl Fraction { - pub fn new(numer: F, denom: F) -> Self { +impl Fraction { + pub fn new(numer: T, denom: T) -> Self { Self { numer: Some(numer), denom, + eval: None, inv: false, } } - pub fn one_over(denom: F) -> Self { + pub fn one_over(denom: T) -> Self { Self { numer: None, denom, + eval: None, inv: false, } } - pub fn denom(&self) -> Option<&F> { + pub fn denom(&self) -> Option<&T> { if !self.inv { Some(&self.denom) } else { @@ -180,7 +169,7 @@ impl Fraction { } } - pub fn denom_mut(&mut self) -> Option<&mut F> { + pub fn denom_mut(&mut self) -> Option<&mut T> { if !self.inv { self.inv = true; Some(&mut self.denom) @@ -190,21 +179,35 @@ impl Fraction { } } -impl Fraction { - pub fn evaluate(&self) -> F { - let denom = if self.inv { - self.denom.clone() - } else { - self.denom.invert().unwrap() - }; - self.numer - .clone() - .map(|numer| numer * &denom) - .unwrap_or(denom) +impl Fraction { + pub fn evaluate(&mut self) { + assert!(self.inv); + assert!(self.eval.is_none()); + + self.eval = Some( + self.numer + .as_ref() + .map(|numer| numer.clone() * &self.denom) + .unwrap_or_else(|| self.denom.clone()), + ); } + + pub fn evaluated(&self) -> &T { + assert!(self.inv); + + self.eval.as_ref().unwrap() + } +} + +pub fn ilog2(value: usize) -> usize { + (usize::BITS - value.leading_zeros() - 1) as usize } -pub fn big_to_fe(big: BigUint) -> F { +pub fn modulus() -> BigUint { + fe_to_big(-F::one()) + 1usize +} + +pub fn fe_from_big(big: BigUint) -> F { let bytes = big.to_bytes_le(); let mut repr = F::Repr::default(); assert!(bytes.len() <= repr.as_ref().len()); @@ -212,10 +215,18 @@ pub fn big_to_fe(big: BigUint) -> F { F::from_repr(repr).unwrap() } +pub fn fe_to_big(fe: F) -> BigUint { + BigUint::from_bytes_le(fe.to_repr().as_ref()) +} + +pub fn fe_to_fe(fe: F1) -> F2 { + fe_from_big(fe_to_big(fe) % modulus::()) +} + pub fn fe_from_limbs( limbs: [F1; LIMBS], ) -> F2 { - big_to_fe( + fe_from_big( limbs .iter() .map(|limb| BigUint::from_bytes_le(limb.to_repr().as_ref())) @@ -234,8 +245,15 @@ pub fn fe_to_limbs> shift) & &mask)) + .map(move |shift| fe_from_big((&big >> shift) & &mask)) .collect_vec() .try_into() .unwrap() } + +pub fn powers(scalar: F) -> impl Iterator +where + for<'a> F: Mul<&'a F, Output = F> + One + Clone, +{ + iter::successors(Some(F::one()), move |power| Some(scalar.clone() * power)) +} diff --git a/src/util/hash.rs b/src/util/hash.rs new file mode 100644 index 00000000..17ede0b3 --- /dev/null +++ b/src/util/hash.rs @@ -0,0 +1,6 @@ +mod poseidon; + +pub use crate::util::hash::poseidon::Poseidon; + +#[cfg(feature = "loader_evm")] +pub use sha3::{Digest, Keccak256}; diff --git a/src/util/hash/poseidon.rs b/src/util/hash/poseidon.rs new file mode 100644 index 00000000..878b69ce --- /dev/null +++ b/src/util/hash/poseidon.rs @@ -0,0 +1,178 @@ +use crate::{ + loader::{LoadedScalar, ScalarLoader}, + util::{arithmetic::FieldExt, Itertools}, +}; +use poseidon::{self, SparseMDSMatrix, Spec}; +use std::{iter, marker::PhantomData, mem}; + +struct State { + inner: [L; T], + _marker: PhantomData, +} + +impl, const T: usize, const RATE: usize> State { + fn new(inner: [L; T]) -> Self { + Self { + inner, + _marker: PhantomData, + } + } + + fn loader(&self) -> &L::Loader { + self.inner[0].loader() + } + + fn power5_with_constant(value: &L, constant: &F) -> L { + value + .loader() + .sum_products_with_const(&[(value, &value.square().square())], *constant) + } + + fn sbox_full(&mut self, constants: &[F; T]) { + for (state, constant) in self.inner.iter_mut().zip(constants.iter()) { + *state = Self::power5_with_constant(state, constant); + } + } + + fn sbox_part(&mut self, constant: &F) { + self.inner[0] = Self::power5_with_constant(&self.inner[0], constant); + } + + fn absorb_with_pre_constants(&mut self, inputs: &[L], pre_constants: &[F; T]) { + assert!(inputs.len() < T); + + self.inner[0] = self + .loader() + .sum_with_const(&[&self.inner[0]], pre_constants[0]); + self.inner + .iter_mut() + .zip(pre_constants.iter()) + .skip(1) + .zip(inputs) + .for_each(|((state, constant), input)| { + *state = state.loader().sum_with_const(&[state, input], *constant); + }); + self.inner + .iter_mut() + .zip(pre_constants.iter()) + .skip(1 + inputs.len()) + .enumerate() + .for_each(|(idx, (state, constant))| { + *state = state.loader().sum_with_const( + &[state], + if idx == 0 { + F::one() + constant + } else { + *constant + }, + ); + }); + } + + fn apply_mds(&mut self, mds: &[[F; T]; T]) { + self.inner = mds + .iter() + .map(|row| { + self.loader() + .sum_with_coeff(&row.iter().cloned().zip(self.inner.iter()).collect_vec()) + }) + .collect_vec() + .try_into() + .unwrap(); + } + + fn apply_sparse_mds(&mut self, mds: &SparseMDSMatrix) { + self.inner = iter::once( + self.loader().sum_with_coeff( + &mds.row() + .iter() + .cloned() + .zip(self.inner.iter()) + .collect_vec(), + ), + ) + .chain( + mds.col_hat() + .iter() + .zip(self.inner.iter().skip(1)) + .map(|(coeff, state)| { + self.loader() + .sum_with_coeff(&[(*coeff, &self.inner[0]), (F::one(), state)]) + }), + ) + .collect_vec() + .try_into() + .unwrap(); + } +} + +pub struct Poseidon { + spec: Spec, + state: State, + buf: Vec, +} + +impl, const T: usize, const RATE: usize> Poseidon { + pub fn new(loader: L::Loader, r_f: usize, r_p: usize) -> Self { + Self { + spec: Spec::new(r_f, r_p), + state: State::new( + poseidon::State::default() + .words() + .map(|state| loader.load_const(&state)), + ), + buf: Vec::new(), + } + } + + pub fn update(&mut self, elements: &[L]) { + self.buf.extend_from_slice(elements); + } + + pub fn squeeze(&mut self) -> L { + let buf = mem::take(&mut self.buf); + let exact = buf.len() % RATE == 0; + + for chunk in buf.chunks(RATE) { + self.permutation(chunk); + } + if exact { + self.permutation(&[]); + } + + self.state.inner[1].clone() + } + + fn permutation(&mut self, inputs: &[L]) { + let r_f = self.spec.r_f() / 2; + let mds = self.spec.mds_matrices().mds().rows(); + let pre_sparse_mds = self.spec.mds_matrices().pre_sparse_mds().rows(); + let sparse_matrices = self.spec.mds_matrices().sparse_matrices(); + + // First half of the full rounds + let constants = self.spec.constants().start(); + self.state.absorb_with_pre_constants(inputs, &constants[0]); + for constants in constants.iter().skip(1).take(r_f - 1) { + self.state.sbox_full(constants); + self.state.apply_mds(&mds); + } + self.state.sbox_full(constants.last().unwrap()); + self.state.apply_mds(&pre_sparse_mds); + + // Partial rounds + let constants = self.spec.constants().partial(); + for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) { + self.state.sbox_part(constant); + self.state.apply_sparse_mds(sparse_mds); + } + + // Second half of the full rounds + let constants = self.spec.constants().end(); + for constants in constants.iter() { + self.state.sbox_full(constants); + self.state.apply_mds(&mds); + } + self.state.sbox_full(&[F::zero(); T]); + self.state.apply_mds(&mds); + } +} diff --git a/src/util/msm.rs b/src/util/msm.rs new file mode 100644 index 00000000..a7a3d45d --- /dev/null +++ b/src/util/msm.rs @@ -0,0 +1,203 @@ +use crate::{ + loader::{LoadedEcPoint, Loader}, + util::arithmetic::CurveAffine, +}; +use std::{ + default::Default, + iter::{self, Sum}, + ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, +}; + +#[derive(Clone, Debug)] +pub struct Msm> { + constant: Option, + scalars: Vec, + bases: Vec, +} + +impl Default for Msm +where + C: CurveAffine, + L: Loader, +{ + fn default() -> Self { + Self { + constant: None, + scalars: Vec::new(), + bases: Vec::new(), + } + } +} + +impl Msm +where + C: CurveAffine, + L: Loader, +{ + pub fn constant(constant: L::LoadedScalar) -> Self { + Msm { + constant: Some(constant), + ..Default::default() + } + } + + pub fn base(base: L::LoadedEcPoint) -> Self { + let one = base.loader().load_one(); + Msm { + scalars: vec![one], + bases: vec![base], + ..Default::default() + } + } + + pub(crate) fn size(&self) -> usize { + self.bases.len() + } + + pub(crate) fn split(mut self) -> (Self, Option) { + let constant = self.constant.take(); + (self, constant) + } + + pub(crate) fn try_into_constant(self) -> Option { + self.bases.is_empty().then(|| self.constant.unwrap()) + } + + pub fn evaluate(self, gen: Option) -> L::LoadedEcPoint { + let gen = gen.map(|gen| { + self.bases + .first() + .unwrap() + .loader() + .ec_point_load_const(&gen) + }); + L::LoadedEcPoint::multi_scalar_multiplication( + iter::empty() + .chain(self.constant.map(|constant| (constant, gen.unwrap()))) + .chain(self.scalars.into_iter().zip(self.bases.into_iter())), + ) + } + + pub fn scale(&mut self, factor: &L::LoadedScalar) { + if let Some(constant) = self.constant.as_mut() { + *constant *= factor; + } + for scalar in self.scalars.iter_mut() { + *scalar *= factor + } + } + + pub fn push(&mut self, scalar: L::LoadedScalar, base: L::LoadedEcPoint) { + if let Some(pos) = self.bases.iter().position(|exist| exist.eq(&base)) { + self.scalars[pos] += scalar; + } else { + self.scalars.push(scalar); + self.bases.push(base); + } + } + + pub fn extend(&mut self, mut other: Self) { + match (self.constant.as_mut(), other.constant.as_ref()) { + (Some(lhs), Some(rhs)) => *lhs += rhs, + (None, Some(_)) => self.constant = other.constant.take(), + _ => {} + }; + for (scalar, base) in other.scalars.into_iter().zip(other.bases) { + self.push(scalar, base); + } + } +} + +impl Add> for Msm +where + C: CurveAffine, + L: Loader, +{ + type Output = Msm; + + fn add(mut self, rhs: Msm) -> Self::Output { + self.extend(rhs); + self + } +} + +impl AddAssign> for Msm +where + C: CurveAffine, + L: Loader, +{ + fn add_assign(&mut self, rhs: Msm) { + self.extend(rhs); + } +} + +impl Sub> for Msm +where + C: CurveAffine, + L: Loader, +{ + type Output = Msm; + + fn sub(mut self, rhs: Msm) -> Self::Output { + self.extend(-rhs); + self + } +} + +impl SubAssign> for Msm +where + C: CurveAffine, + L: Loader, +{ + fn sub_assign(&mut self, rhs: Msm) { + self.extend(-rhs); + } +} + +impl Mul<&L::LoadedScalar> for Msm +where + C: CurveAffine, + L: Loader, +{ + type Output = Msm; + + fn mul(mut self, rhs: &L::LoadedScalar) -> Self::Output { + self.scale(rhs); + self + } +} + +impl MulAssign<&L::LoadedScalar> for Msm +where + C: CurveAffine, + L: Loader, +{ + fn mul_assign(&mut self, rhs: &L::LoadedScalar) { + self.scale(rhs); + } +} + +impl Neg for Msm +where + C: CurveAffine, + L: Loader, +{ + type Output = Msm; + fn neg(mut self) -> Msm { + self.constant = self.constant.map(|constant| -constant); + for scalar in self.scalars.iter_mut() { + *scalar = -scalar.clone(); + } + self + } +} + +impl Sum for Msm +where + C: CurveAffine, + L: Loader, +{ + fn sum>(iter: I) -> Self { + iter.reduce(|acc, item| acc + item).unwrap_or_default() + } +} diff --git a/src/util/expression.rs b/src/util/protocol.rs similarity index 65% rename from src/util/expression.rs rename to src/util/protocol.rs index 41b04ded..bae363ff 100644 --- a/src/util/expression.rs +++ b/src/util/protocol.rs @@ -1,7 +1,12 @@ use crate::{ loader::{LoadedScalar, Loader}, - util::{Curve, Domain, Field, Fraction, Itertools, Rotation}, + util::{ + arithmetic::{CurveAffine, Domain, Field, Fraction, Rotation}, + Itertools, + }, }; +use num_integer::Integer; +use num_traits::One; use std::{ cmp::max, collections::{BTreeMap, BTreeSet}, @@ -19,10 +24,11 @@ pub enum CommonPolynomial { #[derive(Clone, Debug)] pub struct CommonPolynomialEvaluation where - C: Curve, + C: CurveAffine, L: Loader, { zn: L::LoadedScalar, + zn_minus_one: L::LoadedScalar, zn_minus_one_inv: Fraction, identity: L::LoadedScalar, lagrange: BTreeMap>, @@ -30,28 +36,29 @@ where impl CommonPolynomialEvaluation where - C: Curve, + C: CurveAffine, L: Loader, { pub fn new( domain: &Domain, - loader: &L, langranges: impl IntoIterator, z: &L::LoadedScalar, ) -> Self { + let loader = z.loader(); + let zn = z.pow_const(domain.n as u64); let langranges = langranges.into_iter().sorted().dedup().collect_vec(); let one = loader.load_one(); let zn_minus_one = zn.clone() - one; + let zn_minus_one_inv = Fraction::one_over(zn_minus_one.clone()); + let n_inv = loader.load_const(&domain.n_inv); let numer = zn_minus_one.clone() * n_inv; - let omegas = langranges .iter() .map(|&i| loader.load_const(&domain.rotate_scalar(C::Scalar::one(), Rotation(i)))) .collect_vec(); - let lagrange_evals = omegas .iter() .map(|omega| Fraction::new(numer.clone() * omega, z.clone() - omega)) @@ -59,24 +66,29 @@ where Self { zn, - zn_minus_one_inv: Fraction::one_over(zn_minus_one), + zn_minus_one, + zn_minus_one_inv, identity: z.clone(), lagrange: langranges.into_iter().zip(lagrange_evals).collect(), } } - pub fn zn(&self) -> L::LoadedScalar { - self.zn.clone() + pub fn zn(&self) -> &L::LoadedScalar { + &self.zn + } + + pub fn zn_minus_one(&self) -> &L::LoadedScalar { + &self.zn_minus_one } - pub fn zn_minus_one_inv(&self) -> L::LoadedScalar { - self.zn_minus_one_inv.evaluate() + pub fn zn_minus_one_inv(&self) -> &L::LoadedScalar { + self.zn_minus_one_inv.evaluated() } - pub fn get(&self, poly: CommonPolynomial) -> L::LoadedScalar { + pub fn get(&self, poly: CommonPolynomial) -> &L::LoadedScalar { match poly { - CommonPolynomial::Identity => self.identity.clone(), - CommonPolynomial::Lagrange(i) => self.lagrange.get(&i).unwrap().evaluate(), + CommonPolynomial::Identity => &self.identity, + CommonPolynomial::Lagrange(i) => self.lagrange.get(&i).unwrap().evaluated(), } } @@ -87,6 +99,26 @@ where .chain(iter::once(self.zn_minus_one_inv.denom_mut())) .flatten() } + + pub fn evaluate(&mut self) { + self.lagrange + .iter_mut() + .map(|(_, value)| value) + .chain(iter::once(&mut self.zn_minus_one_inv)) + .for_each(Fraction::evaluate) + } +} + +#[derive(Clone, Debug)] +pub struct QuotientPolynomial { + pub chunk_degree: usize, + pub numerator: Expression, +} + +impl QuotientPolynomial { + pub fn num_chunk(&self) -> usize { + Integer::div_ceil(&(self.numerator.degree() - 1), &self.chunk_degree) + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -114,10 +146,11 @@ pub enum Expression { Sum(Box>, Box>), Product(Box>, Box>), Scaled(Box>, F), + DistributePowers(Vec>, Box>), } impl Expression { - pub fn evaluate( + pub fn evaluate( &self, constant: &impl Fn(F) -> T, common_poly: &impl Fn(CommonPolynomial) -> T, @@ -128,83 +161,53 @@ impl Expression { product: &impl Fn(T, T) -> T, scaled: &impl Fn(T, F) -> T, ) -> T { + let evaluate = |expr: &Expression| { + expr.evaluate( + constant, + common_poly, + poly, + challenge, + negated, + sum, + product, + scaled, + ) + }; match self { Expression::Constant(scalar) => constant(scalar.clone()), Expression::CommonPolynomial(poly) => common_poly(*poly), Expression::Polynomial(query) => poly(*query), Expression::Challenge(index) => challenge(*index), Expression::Negated(a) => { - let a = a.evaluate( - constant, - common_poly, - poly, - challenge, - negated, - sum, - product, - scaled, - ); + let a = evaluate(a); negated(a) } Expression::Sum(a, b) => { - let a = a.evaluate( - constant, - common_poly, - poly, - challenge, - negated, - sum, - product, - scaled, - ); - let b = b.evaluate( - constant, - common_poly, - poly, - challenge, - negated, - sum, - product, - scaled, - ); + let a = evaluate(a); + let b = evaluate(b); sum(a, b) } Expression::Product(a, b) => { - let a = a.evaluate( - constant, - common_poly, - poly, - challenge, - negated, - sum, - product, - scaled, - ); - let b = b.evaluate( - constant, - common_poly, - poly, - challenge, - negated, - sum, - product, - scaled, - ); + let a = evaluate(a); + let b = evaluate(b); product(a, b) } Expression::Scaled(a, scalar) => { - let a = a.evaluate( - constant, - common_poly, - poly, - challenge, - negated, - sum, - product, - scaled, - ); + let a = evaluate(a); scaled(a, scalar.clone()) } + Expression::DistributePowers(exprs, scalar) => { + assert!(!exprs.is_empty()); + if exprs.len() == 1 { + return evaluate(exprs.first().unwrap()); + } + let mut exprs = exprs.iter(); + let first = evaluate(exprs.next().unwrap()); + let scalar = evaluate(scalar); + exprs.fold(first, |acc, expr| { + sum(product(acc, scalar.clone()), evaluate(expr)) + }) + } } } @@ -218,6 +221,12 @@ impl Expression { Expression::Sum(a, b) => max(a.degree(), b.degree()), Expression::Product(a, b) => a.degree() + b.degree(), Expression::Scaled(a, _) => a.degree(), + Expression::DistributePowers(a, b) => a + .iter() + .chain(Some(b.as_ref())) + .map(Self::degree) + .max() + .unwrap_or_default(), } } @@ -237,6 +246,20 @@ impl Expression { ) .unwrap_or_default() } + + pub fn used_query(&self) -> BTreeSet { + self.evaluate( + &|_| None, + &|_| None, + &|query| Some(BTreeSet::from_iter([query])), + &|_| None, + &|a| a, + &merge_left_right, + &merge_left_right, + &|a, _| a, + ) + .unwrap_or_default() + } } impl From for Expression { @@ -306,6 +329,12 @@ impl Sum for Expression { } } +impl One for Expression { + fn one() -> Self { + Expression::Constant(F::one()) + } +} + fn merge_left_right(a: Option>, b: Option>) -> Option> { match (a, b) { (Some(a), None) | (None, Some(a)) => Some(a), @@ -316,3 +345,21 @@ fn merge_left_right(a: Option>, b: Option>) -> O _ => None, } } + +#[derive(Clone, Debug)] +pub enum LinearizationStrategy { + /// Older linearization strategy of GWC19, which has linearization + /// polynomial that doesn't evaluate to 0, and requires prover to send extra + /// evaluation of it to verifier. + WithoutConstant, + /// Current linearization strategy of GWC19, which has linearization + /// polynomial that evaluate to 0 by subtracting product of vanishing and + /// quotient polynomials. + MinusVanishingTimesQuotient, +} + +#[derive(Clone, Debug, Default)] +pub struct InstanceCommittingKey { + pub bases: Vec, + pub constant: Option, +} diff --git a/src/util/transcript.rs b/src/util/transcript.rs index a42d5e70..3337324d 100644 --- a/src/util/transcript.rs +++ b/src/util/transcript.rs @@ -1,13 +1,15 @@ use crate::{ - loader::Loader, - {util::Curve, Error}, + loader::{native::NativeLoader, Loader}, + {util::arithmetic::CurveAffine, Error}, }; pub trait Transcript where - C: Curve, + C: CurveAffine, L: Loader, { + fn loader(&self) -> &L; + fn squeeze_challenge(&mut self) -> L::LoadedScalar; fn squeeze_n_challenges(&mut self, n: usize) -> Vec { @@ -21,7 +23,7 @@ where pub trait TranscriptRead: Transcript where - C: Curve, + C: CurveAffine, L: Loader, { fn read_scalar(&mut self) -> Result; @@ -36,3 +38,9 @@ where (0..n).map(|_| self.read_ec_point()).collect() } } + +pub trait TranscriptWrite: Transcript { + fn write_scalar(&mut self, scalar: C::Scalar) -> Result<(), Error>; + + fn write_ec_point(&mut self, ec_point: C) -> Result<(), Error>; +} diff --git a/src/verifier.rs b/src/verifier.rs new file mode 100644 index 00000000..07529603 --- /dev/null +++ b/src/verifier.rs @@ -0,0 +1,51 @@ +use crate::{ + loader::Loader, + pcs::{Decider, MultiOpenScheme}, + util::{arithmetic::CurveAffine, transcript::TranscriptRead}, + Error, Protocol, +}; +use std::fmt::Debug; + +mod plonk; + +pub use plonk::{Plonk, PlonkProof}; + +pub trait PlonkVerifier +where + C: CurveAffine, + L: Loader, + MOS: MultiOpenScheme, +{ + type Proof: Clone + Debug; + + fn read_proof( + svk: &MOS::SuccinctVerifyingKey, + protocol: &Protocol, + instances: &[Vec], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead; + + fn succinct_verify( + svk: &MOS::SuccinctVerifyingKey, + protocol: &Protocol, + instances: &[Vec], + proof: &Self::Proof, + ) -> Result, Error>; + + fn verify( + svk: &MOS::SuccinctVerifyingKey, + dk: &MOS::DecidingKey, + protocol: &Protocol, + instances: &[Vec], + proof: &Self::Proof, + ) -> Result + where + MOS: Decider, + { + let accumulators = Self::succinct_verify(svk, protocol, instances, proof)?; + let output = MOS::decide_all(dk, accumulators); + Ok(output) + } +} diff --git a/src/verifier/plonk.rs b/src/verifier/plonk.rs new file mode 100644 index 00000000..9e08e585 --- /dev/null +++ b/src/verifier/plonk.rs @@ -0,0 +1,464 @@ +use crate::{ + cost::{Cost, CostEstimation}, + loader::{native::NativeLoader, LoadedScalar, Loader}, + pcs::{self, AccumulatorEncoding, MultiOpenScheme}, + util::{ + arithmetic::{CurveAffine, Field, Rotation}, + msm::Msm, + protocol::{ + CommonPolynomial::Lagrange, CommonPolynomialEvaluation, LinearizationStrategy, Query, + }, + transcript::TranscriptRead, + Itertools, + }, + verifier::PlonkVerifier, + Error, Protocol, +}; +use std::{collections::HashMap, iter, marker::PhantomData}; + +pub struct Plonk(PhantomData<(MOS, AE)>); + +impl PlonkVerifier for Plonk +where + C: CurveAffine, + L: Loader, + MOS: MultiOpenScheme, + AE: AccumulatorEncoding, +{ + type Proof = PlonkProof; + + fn read_proof( + svk: &MOS::SuccinctVerifyingKey, + protocol: &Protocol, + instances: &[Vec], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + { + PlonkProof::read::(svk, protocol, instances, transcript) + } + + fn succinct_verify( + svk: &MOS::SuccinctVerifyingKey, + protocol: &Protocol, + instances: &[Vec], + proof: &Self::Proof, + ) -> Result, Error> { + let common_poly_eval = { + let mut common_poly_eval = CommonPolynomialEvaluation::new( + &protocol.domain, + langranges(protocol, instances), + &proof.z, + ); + + L::LoadedScalar::batch_invert(common_poly_eval.denoms()); + common_poly_eval.evaluate(); + + common_poly_eval + }; + + let mut evaluations = proof.evaluations(protocol, instances, &common_poly_eval)?; + let commitments = proof.commitments(protocol, &common_poly_eval, &mut evaluations)?; + let queries = proof.queries(protocol, evaluations); + + let accumulator = MOS::succinct_verify(svk, &commitments, &proof.z, &queries, &proof.pcs)?; + + let accumulators = iter::empty() + .chain(Some(accumulator)) + .chain(proof.old_accumulators.iter().cloned()) + .collect(); + + Ok(accumulators) + } +} + +#[derive(Clone, Debug)] +pub struct PlonkProof +where + C: CurveAffine, + L: Loader, + MOS: MultiOpenScheme, +{ + pub committed_instances: Option>, + pub witnesses: Vec, + pub challenges: Vec, + pub quotients: Vec, + pub z: L::LoadedScalar, + pub evaluations: Vec, + pub pcs: MOS::Proof, + pub old_accumulators: Vec, +} + +impl PlonkProof +where + C: CurveAffine, + L: Loader, + MOS: MultiOpenScheme, +{ + fn read( + svk: &MOS::SuccinctVerifyingKey, + protocol: &Protocol, + instances: &[Vec], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + AE: AccumulatorEncoding, + { + let loader = transcript.loader(); + if let Some(transcript_initial_state) = &protocol.transcript_initial_state { + transcript.common_scalar(&loader.load_const(transcript_initial_state))?; + } + + if protocol.num_instance + != instances + .iter() + .map(|instances| instances.len()) + .collect_vec() + { + return Err(Error::InvalidInstances); + } + + let committed_instances = if let Some(ick) = &protocol.instance_committing_key { + let loader = transcript.loader(); + let bases = ick + .bases + .iter() + .map(|value| loader.ec_point_load_const(value)) + .collect_vec(); + let constant = ick + .constant + .as_ref() + .map(|value| loader.ec_point_load_const(value)); + + let committed_instances = instances + .iter() + .map(|instances| { + instances + .iter() + .zip(bases.iter()) + .map(|(scalar, base)| Msm::::base(base.clone()) * scalar) + .chain(constant.clone().map(|constant| Msm::base(constant))) + .sum::>() + .evaluate(None) + }) + .collect_vec(); + for committed_instance in committed_instances.iter() { + transcript.common_ec_point(committed_instance)?; + } + + Some(committed_instances) + } else { + for instances in instances.iter() { + for instance in instances.iter() { + transcript.common_scalar(instance)?; + } + } + + None + }; + + let (witnesses, challenges) = { + let (witnesses, challenges) = protocol + .num_witness + .iter() + .zip(protocol.num_challenge.iter()) + .map(|(&n, &m)| { + Ok(( + transcript.read_n_ec_points(n)?, + transcript.squeeze_n_challenges(m), + )) + }) + .collect::, Error>>()? + .into_iter() + .unzip::<_, _, Vec<_>, Vec<_>>(); + + ( + witnesses.into_iter().flatten().collect_vec(), + challenges.into_iter().flatten().collect_vec(), + ) + }; + + let quotients = transcript.read_n_ec_points(protocol.quotient.num_chunk())?; + + let z = transcript.squeeze_challenge(); + let evaluations = transcript.read_n_scalars(protocol.evaluations.len())?; + + let pcs = MOS::read_proof(svk, &Self::empty_queries(protocol), transcript)?; + + let old_accumulators = protocol + .accumulator_indices + .iter() + .map(|accumulator_indices| { + accumulator_indices + .iter() + .map(|&(i, j)| instances[i][j].clone()) + .collect() + }) + .map(AE::from_repr) + .collect::, _>>()?; + + Ok(Self { + committed_instances, + witnesses, + challenges, + quotients, + z, + evaluations, + pcs, + old_accumulators, + }) + } + + fn empty_queries(protocol: &Protocol) -> Vec> { + protocol + .queries + .iter() + .map(|query| pcs::Query { + poly: query.poly, + shift: protocol + .domain + .rotate_scalar(C::Scalar::one(), query.rotation), + eval: (), + }) + .collect() + } + + fn queries( + &self, + protocol: &Protocol, + mut evaluations: HashMap, + ) -> Vec> { + Self::empty_queries(protocol) + .into_iter() + .zip( + protocol + .queries + .iter() + .map(|query| evaluations.remove(query).unwrap()), + ) + .map(|(query, eval)| query.with_evaluation(eval)) + .collect() + } + + fn commitments( + &self, + protocol: &Protocol, + common_poly_eval: &CommonPolynomialEvaluation, + evaluations: &mut HashMap, + ) -> Result>, Error> { + let loader = common_poly_eval.zn().loader(); + let mut commitments = iter::empty() + .chain( + protocol + .preprocessed + .iter() + .map(|value| Msm::base(loader.ec_point_load_const(value))), + ) + .chain( + self.committed_instances + .clone() + .map(|committed_instances| { + committed_instances.into_iter().map(Msm::base).collect_vec() + }) + .unwrap_or_else(|| { + iter::repeat_with(Default::default) + .take(protocol.num_instance.len()) + .collect_vec() + }), + ) + .chain(self.witnesses.iter().cloned().map(Msm::base)) + .collect_vec(); + + let numerator = protocol.quotient.numerator.evaluate( + &|scalar| Ok(Msm::constant(loader.load_const(&scalar))), + &|poly| Ok(Msm::constant(common_poly_eval.get(poly).clone())), + &|query| { + evaluations + .get(&query) + .cloned() + .map(Msm::constant) + .or_else(|| { + (query.rotation == Rotation::cur()) + .then(|| commitments.get(query.poly).cloned()) + .flatten() + }) + .ok_or(Error::InvalidQuery(query)) + }, + &|index| { + self.challenges + .get(index) + .cloned() + .map(Msm::constant) + .ok_or(Error::InvalidChallenge(index)) + }, + &|a| Ok(-a?), + &|a, b| Ok(a? + b?), + &|a, b| { + let (a, b) = (a?, b?); + match (a.size(), b.size()) { + (0, _) => Ok(b * &a.try_into_constant().unwrap()), + (_, 0) => Ok(a * &b.try_into_constant().unwrap()), + (_, _) => Err(Error::InvalidLinearization), + } + }, + &|a, scalar| Ok(a? * &loader.load_const(&scalar)), + )?; + + let quotient_query = Query::new( + protocol.preprocessed.len() + protocol.num_instance.len() + self.witnesses.len(), + Rotation::cur(), + ); + let quotient = common_poly_eval + .zn() + .pow_const(protocol.quotient.chunk_degree as u64) + .powers(self.quotients.len()) + .into_iter() + .zip(self.quotients.iter().cloned().map(Msm::base)) + .map(|(coeff, chunk)| chunk * &coeff) + .sum::>(); + match protocol.linearization { + Some(LinearizationStrategy::WithoutConstant) => { + let linearization_query = Query::new(quotient_query.poly + 1, Rotation::cur()); + let (msm, constant) = numerator.split(); + commitments.push(quotient); + commitments.push(msm); + evaluations.insert( + quotient_query, + (constant.unwrap_or_else(|| loader.load_zero()) + + evaluations.get(&linearization_query).unwrap()) + * common_poly_eval.zn_minus_one_inv(), + ); + } + Some(LinearizationStrategy::MinusVanishingTimesQuotient) => { + let (msm, constant) = + (numerator - quotient * common_poly_eval.zn_minus_one()).split(); + commitments.push(msm); + evaluations.insert( + quotient_query, + constant.unwrap_or_else(|| loader.load_zero()), + ); + } + None => { + commitments.push(quotient); + evaluations.insert( + quotient_query, + numerator + .try_into_constant() + .ok_or(Error::InvalidLinearization)? + * common_poly_eval.zn_minus_one_inv(), + ); + } + } + + Ok(commitments) + } + + fn evaluations( + &self, + protocol: &Protocol, + instances: &[Vec], + common_poly_eval: &CommonPolynomialEvaluation, + ) -> Result, Error> { + let loader = common_poly_eval.zn().loader(); + let instance_evals = protocol.instance_committing_key.is_none().then(|| { + let offset = protocol.preprocessed.len(); + let queries = { + let range = offset..offset + protocol.num_instance.len(); + protocol + .quotient + .numerator + .used_query() + .into_iter() + .filter(move |query| range.contains(&query.poly)) + }; + queries + .map(move |query| { + let instances = instances[query.poly - offset].iter(); + let l_i_minus_r = (-query.rotation.0..) + .map(|i_minus_r| common_poly_eval.get(Lagrange(i_minus_r))); + let eval = loader.sum_products(&instances.zip(l_i_minus_r).collect_vec()); + (query, eval) + }) + .collect_vec() + }); + + let evals = iter::empty() + .chain(instance_evals.into_iter().flatten()) + .chain( + protocol + .evaluations + .iter() + .cloned() + .zip(self.evaluations.iter().cloned()), + ) + .collect(); + + Ok(evals) + } +} + +impl CostEstimation<(C, MOS)> for Plonk +where + C: CurveAffine, + MOS: MultiOpenScheme + CostEstimation>>, +{ + type Input = Protocol; + + fn estimate_cost(protocol: &Protocol) -> Cost { + let plonk_cost = { + let num_accumulator = protocol.accumulator_indices.len(); + let num_instance = protocol.num_instance.iter().sum(); + let num_commitment = + protocol.num_witness.iter().sum::() + protocol.quotient.num_chunk(); + let num_evaluation = protocol.evaluations.len(); + let num_msm = protocol.preprocessed.len() + num_commitment + 1 + 2 * num_accumulator; + Cost::new(num_instance, num_commitment, num_evaluation, num_msm) + }; + let pcs_cost = { + let queries = PlonkProof::::empty_queries(protocol); + MOS::estimate_cost(&queries) + }; + plonk_cost + pcs_cost + } +} + +fn langranges(protocol: &Protocol, instances: &[Vec]) -> impl IntoIterator +where + C: CurveAffine, +{ + let instance_eval_lagrange = protocol.instance_committing_key.is_none().then(|| { + let queries = { + let offset = protocol.preprocessed.len(); + let range = offset..offset + protocol.num_instance.len(); + protocol + .quotient + .numerator + .used_query() + .into_iter() + .filter(move |query| range.contains(&query.poly)) + }; + let (min_rotation, max_rotation) = queries.fold((0, 0), |(min, max), query| { + if query.rotation.0 < min { + (query.rotation.0, max) + } else if query.rotation.0 > max { + (min, query.rotation.0) + } else { + (min, max) + } + }); + let max_instance_len = instances + .iter() + .map(|instance| instance.len()) + .max() + .unwrap_or_default(); + -max_rotation..max_instance_len as i32 + min_rotation.abs() + }); + protocol + .quotient + .numerator + .used_langrange() + .into_iter() + .chain(instance_eval_lagrange.into_iter().flatten()) +} From 0fe63f86d2eebfc5d784119e7c15a5ee8384014e Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Mon, 24 Oct 2022 18:56:32 -0700 Subject: [PATCH 05/28] chore: update dependencies with latest halo2_proofs --- Cargo.toml | 7 ++- src/loader/halo2/loader.rs | 6 +- src/system/halo2/aggregation.rs | 106 +++++++++++++++++++++++++++++++- 3 files changed, 112 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6df8b804..e568089a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,10 +16,10 @@ serde_json = "1.0" hex = "0.4.3" ark-std = { version = "0.3", features = ["print-trace"] } -halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.2.1", package = "halo2curves" } +halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.3.0", package = "halo2curves" } # system_halo2 -halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2", tag = "v2022_09_10", optional = true } +halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2", tag = "v2022_10_22", optional = true } # loader_evm ethereum_types = { package = "ethereum-types", version = "0.13.1", default-features = false, features = ["std"], optional = true } @@ -43,11 +43,12 @@ zkevm_circuit_benchmarks = {git = "https://github.com/privacy-scaling-exploratio zkevm_circuits = {git = "https://github.com/privacy-scaling-explorations/zkevm-circuits.git", package = "zkevm-circuits" } [features] -default = ["loader_evm", "loader_halo2", "system_halo2", "display"] +default = ["loader_evm", "loader_halo2", "system_halo2", "display", "serialize"] loader_evm = ["dep:ethereum_types", "dep:sha3", "dep:foundry_evm"] loader_halo2 = ["dep:halo2_proofs", "dep:halo2_base", "halo2_ecc", "dep:poseidon"] system_halo2 = ["dep:halo2_proofs"] display = ["halo2_ecc/display"] +serialize = [] sanity_check = [] [patch."https://github.com/privacy-scaling-explorations/halo2"] diff --git a/src/loader/halo2/loader.rs b/src/loader/halo2/loader.rs index 04b5a2f4..d5bcfb26 100644 --- a/src/loader/halo2/loader.rs +++ b/src/loader/halo2/loader.rs @@ -86,7 +86,7 @@ where self.ctx.borrow() } - pub(crate) fn ctx_mut(&self) -> impl DerefMut> + '_ { + pub fn ctx_mut(&self) -> impl DerefMut> + '_ { self.ctx.borrow_mut() } @@ -139,6 +139,10 @@ where Scalar { loader: self.clone(), index, value } } + pub fn scalar_from_assigned(self: &Rc, assigned: AssignedValue) -> Scalar<'a, 'b, C> { + self.scalar(Value::Assigned(assigned)) + } + pub fn ec_point(self: &Rc, assigned: AssignedEcPoint) -> EcPoint<'a, 'b, C> { let index = *self.num_ec_point.borrow(); *self.num_ec_point.borrow_mut() += 1; diff --git a/src/system/halo2/aggregation.rs b/src/system/halo2/aggregation.rs index 67ffc809..5b50bceb 100644 --- a/src/system/halo2/aggregation.rs +++ b/src/system/halo2/aggregation.rs @@ -1,6 +1,6 @@ use super::{BITS, LIMBS}; use crate::{ - loader::{self, native::NativeLoader}, + loader::{self, native::NativeLoader, Loader}, pcs::{ kzg::{ Bdfg21, Kzg, KzgAccumulator, KzgAs, KzgAsProvingKey, KzgAsVerifyingKey, @@ -124,6 +124,14 @@ impl SnarkWitness { } } + pub fn protocol(&self) -> &Protocol { + &self.protocol + } + + pub fn instances(&self) -> &[Vec>] { + &self.instances + } + pub fn proof(&self) -> Value<&[u8]> { self.proof.as_ref().map(Vec::as_slice) } @@ -189,6 +197,98 @@ pub fn aggregate<'a, 'b>( .collect_vec() } +pub fn recursive_aggregate<'a, 'b>( + svk: &Svk, + loader: &Rc>, + snarks: &[SnarkWitness], + recursive_snark: &SnarkWitness, + as_vk: &AsVk, + as_proof: Value<&'_ [u8]>, + use_dummy: AssignedValue, +) -> (Vec>, Vec>>) { + let assign_instances = |instances: &[Vec>]| { + instances + .iter() + .map(|instances| { + instances.iter().map(|instance| loader.assign_scalar(*instance)).collect_vec() + }) + .collect_vec() + }; + + let mut assigned_instances = vec![]; + let mut accumulators = snarks + .iter() + .flat_map(|snark| { + let instances = assign_instances(&snark.instances); + assigned_instances.push( + instances + .iter() + .flat_map(|instance| instance.iter().map(|scalar| scalar.assigned())) + .collect_vec(), + ); + let mut transcript = + PoseidonTranscript::, _, _>::new(loader, snark.proof()); + let proof = + Plonk::read_proof(svk, &snark.protocol, &instances, &mut transcript).unwrap(); + Plonk::succinct_verify(svk, &snark.protocol, &instances, &proof).unwrap() + }) + .collect_vec(); + + let use_dummy = loader.scalar_from_assigned(use_dummy); + + let prev_instances = assign_instances(&recursive_snark.instances); + let mut accs = { + let mut transcript = + PoseidonTranscript::, _, _>::new(loader, recursive_snark.proof()); + let proof = + Plonk::read_proof(svk, &recursive_snark.protocol, &prev_instances, &mut transcript) + .unwrap(); + let mut accs = Plonk::succinct_verify_or_dummy( + svk, + &recursive_snark.protocol, + &prev_instances, + &proof, + &use_dummy, + ) + .unwrap(); + for acc in accs.iter_mut() { + (*acc).lhs = + loader.ec_point_select(&accumulators[0].lhs, &acc.lhs, &use_dummy).unwrap(); + (*acc).rhs = + loader.ec_point_select(&accumulators[0].rhs, &acc.rhs, &use_dummy).unwrap(); + } + accs + }; + accumulators.append(&mut accs); + + let KzgAccumulator { lhs, rhs } = { + let mut transcript = PoseidonTranscript::, _, _>::new(loader, as_proof); + let proof = As::read_proof(as_vk, &accumulators, &mut transcript).unwrap(); + As::verify(as_vk, &accumulators, &proof).unwrap() + }; + + let lhs = lhs.assigned(); + let rhs = rhs.assigned(); + + let mut new_instances = prev_instances + .iter() + .flat_map(|instance| instance.iter().map(|scalar| scalar.assigned())) + .collect_vec(); + for (i, acc_limb) in lhs + .x + .truncation + .limbs + .iter() + .chain(lhs.y.truncation.limbs.iter()) + .chain(rhs.x.truncation.limbs.iter()) + .chain(rhs.y.truncation.limbs.iter()) + .enumerate() + { + new_instances[i] = acc_limb.clone(); + } + (new_instances, assigned_instances) +} + #[derive(Clone)] pub struct AggregationCircuit { svk: Svk, @@ -492,9 +592,9 @@ pub trait TargetCircuit { } // this is a toggle that should match the fork of halo2_proofs you are using -// the current default in PSE/main is `true`, while there is a PR to make it `false`: +// the current default in PSE/main is `false`, before 2022_10_22 it was `false`: // see https://github.com/privacy-scaling-explorations/halo2/pull/96/files -pub const KZG_QUERY_INSTANCE: bool = true; +pub const KZG_QUERY_INSTANCE: bool = false; pub fn create_snark_shplonk( params: &ParamsKZG, From 4e8d9c235aa0746f29b575599dc8e076c9149b2f Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Tue, 25 Oct 2022 00:07:13 -0700 Subject: [PATCH 06/28] add `serialize` feature to turn on/off vkey/pkey write --- Cargo.toml | 4 ++-- src/system/halo2/aggregation.rs | 25 +++++++++++++++++++++++-- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e568089a..51935591 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ foundry_evm = { git = "https://github.com/jonathanpwang/foundry", package = "fou # loader_halo2 halo2_base = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", package = "halo2_base", default-features = false, optional = true } halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", package = "halo2_ecc", default-features = false, optional = true } -poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon", branch = "padding", optional = true } +poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon", tag = "v2022_10_22", optional = true } [dev-dependencies] paste = "1.0.7" @@ -53,7 +53,7 @@ sanity_check = [] [patch."https://github.com/privacy-scaling-explorations/halo2"] halo2_proofs = { git = "https://github.com/axiom-crypto/halo2.git", branch = "axiom/add-serialization", package = "halo2_proofs" } - + [[example]] name = "evm-verifier" required-features = ["loader_evm", "system_halo2"] diff --git a/src/system/halo2/aggregation.rs b/src/system/halo2/aggregation.rs index 5b50bceb..f9d2af82 100644 --- a/src/system/halo2/aggregation.rs +++ b/src/system/halo2/aggregation.rs @@ -481,12 +481,14 @@ pub fn gen_vk>( name: &str, ) -> VerifyingKey { let path = format!("./data/{}_{}.vkey", name, params.k()); + #[cfg(feature = "serialize")] match File::open(path.as_str()) { Ok(f) => { - println!("Reading vkey from {}", path); + let read_time = start_timer!(|| format!("Reading vkey from {}", path)); let mut bufreader = BufReader::new(f); let vk = VerifyingKey::read::<_, ConcreteCircuit>(&mut bufreader, params) .expect("Reading vkey should not fail"); + end_timer!(read_time); vk } Err(_) => { @@ -499,6 +501,13 @@ pub fn gen_vk>( vk } } + #[cfg(not(feature = "serialize"))] + { + let vk_time = start_timer!(|| "vkey"); + let vk = keygen_vk(params, circuit).unwrap(); + end_timer!(vk_time); + vk + } } pub fn gen_pk>( @@ -507,12 +516,14 @@ pub fn gen_pk>( name: &str, ) -> ProvingKey { let path = format!("./data/{}_{}.pkey", name, params.k()); + #[cfg(feature = "serialize")] match File::open(path.as_str()) { Ok(f) => { - println!("Reading pkey from {}", path); + let read_time = start_timer!(|| format!("Reading pkey from {}", path)); let mut bufreader = BufReader::new(f); let pk = ProvingKey::read::<_, ConcreteCircuit>(&mut bufreader, params) .expect("Reading pkey should not fail"); + end_timer!(read_time); pk } Err(_) => { @@ -526,6 +537,14 @@ pub fn gen_pk>( pk } } + #[cfg(not(feature = "serialize"))] + { + let vk = gen_vk::(params, circuit, name); + let pk_time = start_timer!(|| "pkey"); + let pk = keygen_pk(params, vk, circuit).unwrap(); + end_timer!(pk_time); + pk + } } pub fn read_bytes(path: &str) -> Vec { @@ -633,9 +652,11 @@ pub fn create_snark_shplonk( && Path::new(path.as_str()).exists() && cached_instances.unwrap() == instances { + let proof_time = start_timer!(|| "read proof"); let mut file = File::open(path.as_str()).unwrap(); let mut buf = vec![]; file.read_to_end(&mut buf).unwrap(); + end_timer!(proof_time); buf } else { let proof_time = start_timer!(|| "create proof"); From 577f2837eaf60f6e599af20d4f1db9ce5d627d6a Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Tue, 25 Oct 2022 10:45:58 -0700 Subject: [PATCH 07/28] feat: change `name` to function instead of const in `TargetCircuit` --- examples/evm-verifier-with-accumulator.rs | 4 +++- src/system/halo2/aggregation.rs | 29 +++++++++++++++++++---- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/examples/evm-verifier-with-accumulator.rs b/examples/evm-verifier-with-accumulator.rs index 1bf58f28..eb4f19c6 100644 --- a/examples/evm-verifier-with-accumulator.rs +++ b/examples/evm-verifier-with-accumulator.rs @@ -306,9 +306,11 @@ pub fn load_verify_circuit_degree() -> u32 { impl TargetCircuit for StandardPlonk { const N_PROOFS: usize = 1; - const NAME: &'static str = "standard_plonk"; type Circuit = Self; + fn name() -> String { + "standard_plonk".to_string() + } } fn main() { diff --git a/src/system/halo2/aggregation.rs b/src/system/halo2/aggregation.rs index f9d2af82..17197063 100644 --- a/src/system/halo2/aggregation.rs +++ b/src/system/halo2/aggregation.rs @@ -605,9 +605,10 @@ pub fn write_instances(instances: &Vec>>, path: &str) { pub trait TargetCircuit { const N_PROOFS: usize; - const NAME: &'static str; type Circuit: Circuit; + + fn name() -> String; } // this is a toggle that should match the fork of halo2_proofs you are using @@ -621,7 +622,7 @@ pub fn create_snark_shplonk( instances: Vec>>, // instances[i][j][..] is the i-th circuit's j-th instance column accumulator_indices: Option>, ) -> Snark { - println!("CREATING SNARK FOR: {}", T::NAME); + println!("CREATING SNARK FOR: {}", T::name()); let config = if let Some(accumulator_indices) = accumulator_indices { Config::kzg(KZG_QUERY_INSTANCE) .set_zk(true) @@ -631,7 +632,7 @@ pub fn create_snark_shplonk( Config::kzg(KZG_QUERY_INSTANCE).set_zk(true).with_num_proof(T::N_PROOFS) }; - let pk = gen_pk(params, &circuits[0], T::NAME); + let pk = gen_pk(params, &circuits[0], T::name().as_str()); // num_instance[i] is length of the i-th instance columns in circuit 0 (all circuits should have same shape of instances) let num_instance = instances[0].iter().map(|instance_column| instance_column.len()).collect(); let protocol = compile(params, pk.get_vk(), config.with_num_instance(num_instance)); @@ -645,9 +646,10 @@ pub fn create_snark_shplonk( // TODO: need to cache the instances as well! let proof = { - let path = format!("./data/proof_{}_{}.dat", T::NAME, params.k()); - let instance_path = format!("./data/instances_{}_{}.dat", T::NAME, params.k()); + let path = format!("./data/proof_{}_{}.dat", T::name(), params.k()); + let instance_path = format!("./data/instances_{}_{}.dat", T::name(), params.k()); let cached_instances = read_instances::(instance_path.as_str()); + #[cfg(feature = "serialize")] if cached_instances.is_some() && Path::new(path.as_str()).exists() && cached_instances.unwrap() == instances @@ -677,6 +679,23 @@ pub fn create_snark_shplonk( end_timer!(proof_time); proof } + #[cfg(not(feature = "serialize"))] + { + let proof_time = start_timer!(|| "create proof"); + let mut transcript = PoseidonTranscript::, _>::init(Vec::new()); + create_proof::, ProverSHPLONK<_>, ChallengeScalar<_>, _, _, _>( + params, + &pk, + &circuits, + instances2.as_slice(), + &mut ChaCha20Rng::from_seed(Default::default()), + &mut transcript, + ) + .unwrap(); + let proof = transcript.finalize(); + end_timer!(proof_time); + proof + } }; let verify_time = start_timer!(|| "verify proof"); From ce657ffb7a08ae07ff22c27037d4078fc28503e3 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Tue, 25 Oct 2022 16:23:51 -0700 Subject: [PATCH 08/28] fix: `load_verify_circuit_degree` now takes in environmental variable `VERIFY_CONFIG` --- examples/evm-verifier-with-accumulator.rs | 16 +++++++++++----- src/system/halo2/aggregation.rs | 13 +++++++------ 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/examples/evm-verifier-with-accumulator.rs b/examples/evm-verifier-with-accumulator.rs index eb4f19c6..2b913c90 100644 --- a/examples/evm-verifier-with-accumulator.rs +++ b/examples/evm-verifier-with-accumulator.rs @@ -39,7 +39,11 @@ use plonk_verifier::{ verifier::{self, PlonkVerifier}, }; use rand::rngs::OsRng; -use std::{fs, io::Cursor, rc::Rc}; +use std::{ + fs::{self, File}, + io::Cursor, + rc::Rc, +}; const LIMBS: usize = 3; const BITS: usize = 88; @@ -296,11 +300,12 @@ fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) } pub fn load_verify_circuit_degree() -> u32 { - let path = "./configs/verify_circuit.config"; - let params_str = - std::fs::read_to_string(path).expect(format!("{} file should exist", path).as_str()); + let path = std::env::var("VERIFY_CONFIG").expect("export VERIFY_CONFIG with config path"); let params: plonk_verifier::system::halo2::Halo2VerifierCircuitConfigParams = - serde_json::from_str(params_str.as_str()).unwrap(); + serde_json::from_reader( + File::open(path.clone()).expect(format!("{} file should exist", path).as_str()), + ) + .unwrap(); params.degree } @@ -314,6 +319,7 @@ impl TargetCircuit for StandardPlonk { } fn main() { + std::env::set_var("VERIFY_CONFIG", "./configs/verify_circuit.config"); let k = load_verify_circuit_degree(); let params = gen_srs(k); diff --git a/src/system/halo2/aggregation.rs b/src/system/halo2/aggregation.rs index 17197063..40a154f7 100644 --- a/src/system/halo2/aggregation.rs +++ b/src/system/halo2/aggregation.rs @@ -47,7 +47,7 @@ use num_bigint::BigUint; use num_traits::Num; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; use std::{ - fs::{self, File}, + fs::File, io::{BufReader, BufWriter, Cursor, Read, Write}, path::Path, rc::Rc, @@ -440,10 +440,11 @@ impl Circuit for AggregationCircuit { } fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { - let path = "./configs/verify_circuit.config"; - let params_str = fs::read_to_string(path).expect(format!("{} should exist", path).as_str()); - let params: Halo2VerifierCircuitConfigParams = - serde_json::from_str(params_str.as_str()).unwrap(); + let path = std::env::var("VERIFY_CONFIG").expect("export VERIFY_CONFIG with config path"); + let params: Halo2VerifierCircuitConfigParams = serde_json::from_reader( + File::open(path.as_str()).expect(format!("{} file should exist", path).as_str()), + ) + .unwrap(); Halo2VerifierCircuitConfig::configure(meta, params) } @@ -612,7 +613,7 @@ pub trait TargetCircuit { } // this is a toggle that should match the fork of halo2_proofs you are using -// the current default in PSE/main is `false`, before 2022_10_22 it was `false`: +// the current default in PSE/main is `false`, before 2022_10_22 it was `true`: // see https://github.com/privacy-scaling-explorations/halo2/pull/96/files pub const KZG_QUERY_INSTANCE: bool = false; From 2cd8b9d1e57949979db1cbfa8506518e63d27e93 Mon Sep 17 00:00:00 2001 From: Han Date: Fri, 28 Oct 2022 07:26:06 -0700 Subject: [PATCH 09/28] Generalized `Halo2Loader` (#12) * feat: generalize `Protocol` for further usage * feat: add `EccInstruction::{fixed_base_msm,variable_base_msm,sum_with_const}` * chore: move `rand_chacha` as dev dependency --- Cargo.toml | 2 +- examples/evm-verifier-with-accumulator.rs | 11 +- examples/evm-verifier.rs | 3 +- src/lib.rs | 10 +- src/loader.rs | 33 ++-- src/loader/evm/loader.rs | 148 ++++++++------- src/loader/halo2.rs | 42 +++++ src/loader/halo2/loader.rs | 220 +++++++++++++--------- src/loader/halo2/shim.rs | 59 ++++-- src/loader/native.rs | 20 +- src/pcs/kzg/multiopen/bdfg21.rs | 4 +- src/system/halo2.rs | 9 +- src/system/halo2/test/kzg.rs | 2 +- src/system/halo2/test/kzg/evm.rs | 7 +- src/system/halo2/test/kzg/halo2.rs | 6 +- src/system/halo2/transcript/evm.rs | 4 +- src/system/halo2/transcript/halo2.rs | 14 +- src/util/hash/poseidon.rs | 2 +- src/util/msm.rs | 12 +- src/util/protocol.rs | 32 ++++ src/verifier.rs | 6 +- src/verifier/plonk.rs | 29 +-- 22 files changed, 423 insertions(+), 252 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index de037c4c..803beade 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,6 @@ num-bigint = "0.4.3" num-integer = "0.1.45" num-traits = "0.2.15" rand = "0.8" -rand_chacha = "0.3.1" halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.3.0", package = "halo2curves" } # system_halo2 @@ -25,6 +24,7 @@ halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2 poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon", tag = "v2022_10_22", optional = true } [dev-dependencies] +rand_chacha = "0.3.1" paste = "1.0.7" # system_halo2 diff --git a/examples/evm-verifier-with-accumulator.rs b/examples/evm-verifier-with-accumulator.rs index 69def21e..66fa8ddf 100644 --- a/examples/evm-verifier-with-accumulator.rs +++ b/examples/evm-verifier-with-accumulator.rs @@ -280,12 +280,12 @@ mod aggregation { let accumulators = snarks .iter() .flat_map(|snark| { + let protocol = snark.protocol.loaded(loader); let instances = assign_instances(&snark.instances); let mut transcript = PoseidonTranscript::, _>::new(loader, snark.proof()); - let proof = - Plonk::read_proof(svk, &snark.protocol, &instances, &mut transcript).unwrap(); - Plonk::succinct_verify(svk, &snark.protocol, &instances, &proof).unwrap() + let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript).unwrap(); + Plonk::succinct_verify(svk, &protocol, &instances, &proof).unwrap() }) .collect_vec(); @@ -555,11 +555,12 @@ fn gen_aggregation_evm_verifier( vk, Config::kzg() .with_num_instance(num_instance.clone()) - .with_accumulator_indices(accumulator_indices), + .with_accumulator_indices(Some(accumulator_indices)), ); let loader = EvmLoader::new::(); - let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); + let protocol = protocol.loaded(&loader); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); let instances = transcript.load_instances(num_instance); let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); diff --git a/examples/evm-verifier.rs b/examples/evm-verifier.rs index b51a9a30..9ed70b36 100644 --- a/examples/evm-verifier.rs +++ b/examples/evm-verifier.rs @@ -216,7 +216,8 @@ fn gen_evm_verifier( ); let loader = EvmLoader::new::(); - let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); + let protocol = protocol.loaded(&loader); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); let instances = transcript.load_instances(num_instance); let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 4e8da5fb..9749924d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,10 +20,14 @@ pub enum Error { } #[derive(Clone, Debug)] -pub struct Protocol { +pub struct Protocol +where + C: util::arithmetic::CurveAffine, + L: loader::Loader, +{ // Common description pub domain: util::arithmetic::Domain, - pub preprocessed: Vec, + pub preprocessed: Vec, pub num_instance: Vec, pub num_witness: Vec, pub num_challenge: Vec, @@ -31,7 +35,7 @@ pub struct Protocol { pub queries: Vec, pub quotient: util::protocol::QuotientPolynomial, // Minor customization - pub transcript_initial_state: Option, + pub transcript_initial_state: Option, pub instance_committing_key: Option>, pub linearization: Option, pub accumulator_indices: Vec>, diff --git a/src/loader.rs b/src/loader.rs index 8c39bae0..5a040f9f 100644 --- a/src/loader.rs +++ b/src/loader.rs @@ -19,15 +19,6 @@ pub trait LoadedEcPoint: Clone + Debug + PartialEq { type Loader: Loader; fn loader(&self) -> &Self::Loader; - - fn multi_scalar_multiplication( - pairs: impl IntoIterator< - Item = ( - >::LoadedScalar, - Self, - ), - >, - ) -> Self; } pub trait LoadedScalar: Clone + Debug + PartialEq + FieldOps { @@ -43,15 +34,6 @@ pub trait LoadedScalar: Clone + Debug + PartialEq + FieldOps { FieldOps::invert(self) } - fn batch_invert<'a>(values: impl IntoIterator) - where - Self: 'a, - { - values - .into_iter() - .for_each(|value| *value = LoadedScalar::invert(value).unwrap_or_else(|| value.clone())) - } - fn pow_const(&self, mut exp: u64) -> Self { assert!(exp > 0); @@ -102,6 +84,12 @@ pub trait EcPointLoader { lhs: &Self::LoadedEcPoint, rhs: &Self::LoadedEcPoint, ) -> Result<(), Error>; + + fn multi_scalar_multiplication( + pairs: &[(Self::LoadedScalar, Self::LoadedEcPoint)], + ) -> Self::LoadedEcPoint + where + Self: ScalarLoader; } pub trait ScalarLoader { @@ -226,6 +214,15 @@ pub trait ScalarLoader { .iter() .fold(self.load_one(), |acc, value| acc * *value) } + + fn batch_invert<'a>(values: impl IntoIterator) + where + Self::LoadedScalar: 'a, + { + values + .into_iter() + .for_each(|value| *value = LoadedScalar::invert(value).unwrap_or_else(|| value.clone())) + } } pub trait Loader: diff --git a/src/loader/evm/loader.rs b/src/loader/evm/loader.rs index 06ab7fd8..7d1dbb94 100644 --- a/src/loader/evm/loader.rs +++ b/src/loader/evm/loader.rs @@ -596,17 +596,6 @@ where fn loader(&self) -> &Rc { &self.loader } - - fn multi_scalar_multiplication(pairs: impl IntoIterator) -> Self { - pairs - .into_iter() - .map(|(scalar, ec_point)| match scalar.value { - Value::Constant(constant) if constant == U256::one() => ec_point, - _ => ec_point.loader.ec_point_scalar_mul(&ec_point, &scalar), - }) - .reduce(|acc, ec_point| acc.loader.ec_point_add(&acc, &ec_point)) - .unwrap() - } } #[derive(Clone)] @@ -759,73 +748,12 @@ impl> LoadedScalar for Scalar { fn loader(&self) -> &Rc { &self.loader } - - fn batch_invert<'a>(values: impl IntoIterator) { - let values = values.into_iter().collect_vec(); - let loader = &values.first().unwrap().loader; - let products = iter::once(values[0].clone()) - .chain( - iter::repeat_with(|| loader.allocate(0x20)) - .map(|ptr| loader.scalar(Value::Memory(ptr))) - .take(values.len() - 1), - ) - .collect_vec(); - - loader.code.borrow_mut().push(loader.scalar_modulus); - for _ in 2..values.len() { - loader.code.borrow_mut().dup(0); - } - - loader.push(products.first().unwrap()); - for (idx, (value, product)) in values.iter().zip(products.iter()).skip(1).enumerate() { - loader.push(value); - loader.code.borrow_mut().mulmod(); - if idx < values.len() - 2 { - loader.code.borrow_mut().dup(0); - } - loader.code.borrow_mut().push(product.ptr()).mstore(); - } - - let inv = loader.invert(products.last().unwrap()); - - loader.code.borrow_mut().push(loader.scalar_modulus); - for _ in 2..values.len() { - loader.code.borrow_mut().dup(0); - } - - loader.push(&inv); - for (value, product) in values.iter().rev().zip( - products - .iter() - .rev() - .skip(1) - .map(Some) - .chain(iter::once(None)), - ) { - if let Some(product) = product { - loader.push(value); - loader - .code - .borrow_mut() - .dup(2) - .dup(2) - .push(product.ptr()) - .mload() - .mulmod() - .push(value.ptr()) - .mstore() - .mulmod(); - } else { - loader.code.borrow_mut().push(value.ptr()).mstore(); - } - } - } } impl EcPointLoader for Rc where C: CurveAffine, - C::Scalar: PrimeField, + C::ScalarExt: PrimeField, { type LoadedEcPoint = EcPoint; @@ -839,6 +767,19 @@ where fn ec_point_assert_eq(&self, _: &str, _: &EcPoint, _: &EcPoint) -> Result<(), Error> { unimplemented!() } + + fn multi_scalar_multiplication( + pairs: &[(>::LoadedScalar, EcPoint)], + ) -> EcPoint { + pairs + .iter() + .map(|(scalar, ec_point)| match scalar.value { + Value::Constant(constant) if U256::one() == constant => ec_point.clone(), + _ => ec_point.loader.ec_point_scalar_mul(ec_point, scalar), + }) + .reduce(|acc, ec_point| acc.loader.ec_point_add(&acc, &ec_point)) + .unwrap() + } } impl> ScalarLoader for Rc { @@ -977,6 +918,67 @@ impl> ScalarLoader for Rc { self.scalar(Value::Memory(ptr)) } + + fn batch_invert<'a>(values: impl IntoIterator) { + let values = values.into_iter().collect_vec(); + let loader = &values.first().unwrap().loader; + let products = iter::once(values[0].clone()) + .chain( + iter::repeat_with(|| loader.allocate(0x20)) + .map(|ptr| loader.scalar(Value::Memory(ptr))) + .take(values.len() - 1), + ) + .collect_vec(); + + loader.code.borrow_mut().push(loader.scalar_modulus); + for _ in 2..values.len() { + loader.code.borrow_mut().dup(0); + } + + loader.push(products.first().unwrap()); + for (idx, (value, product)) in values.iter().zip(products.iter()).skip(1).enumerate() { + loader.push(value); + loader.code.borrow_mut().mulmod(); + if idx < values.len() - 2 { + loader.code.borrow_mut().dup(0); + } + loader.code.borrow_mut().push(product.ptr()).mstore(); + } + + let inv = loader.invert(products.last().unwrap()); + + loader.code.borrow_mut().push(loader.scalar_modulus); + for _ in 2..values.len() { + loader.code.borrow_mut().dup(0); + } + + loader.push(&inv); + for (value, product) in values.iter().rev().zip( + products + .iter() + .rev() + .skip(1) + .map(Some) + .chain(iter::once(None)), + ) { + if let Some(product) = product { + loader.push(value); + loader + .code + .borrow_mut() + .dup(2) + .dup(2) + .push(product.ptr()) + .mload() + .mulmod() + .push(value.ptr()) + .mstore() + .mulmod(); + } else { + loader.code.borrow_mut().push(value.ptr()).mstore(); + } + } + } } impl Loader for Rc diff --git a/src/loader/halo2.rs b/src/loader/halo2.rs index 9eae74e3..39e95c23 100644 --- a/src/loader/halo2.rs +++ b/src/loader/halo2.rs @@ -1,3 +1,7 @@ +use crate::{util::arithmetic::CurveAffine, Protocol}; +use halo2_proofs::circuit; +use std::rc::Rc; + pub(crate) mod loader; mod shim; @@ -27,3 +31,41 @@ mod util { impl>> Valuetools for I {} } + +impl Protocol +where + C: CurveAffine, +{ + pub fn loaded_preprocessed_as_witness<'a, EccChip: EccInstructions<'a, C>>( + &self, + loader: &Rc>, + ) -> Protocol>> { + let preprocessed = self + .preprocessed + .iter() + .map(|preprocessed| loader.assign_ec_point(circuit::Value::known(*preprocessed))) + .collect(); + let transcript_initial_state = + self.transcript_initial_state + .as_ref() + .map(|transcript_initial_state| { + loader.assign_scalar(circuit::Value::known( + loader.scalar_chip().integer(*transcript_initial_state), + )) + }); + Protocol { + domain: self.domain.clone(), + preprocessed, + num_instance: self.num_instance.clone(), + num_witness: self.num_witness.clone(), + num_challenge: self.num_challenge.clone(), + evaluations: self.evaluations.clone(), + queries: self.queries.clone(), + quotient: self.quotient.clone(), + transcript_initial_state, + instance_committing_key: self.instance_committing_key.clone(), + linearization: self.linearization.clone(), + accumulator_indices: self.accumulator_indices.clone(), + } + } +} diff --git a/src/loader/halo2/loader.rs b/src/loader/halo2/loader.rs index 0d288ce7..d446c716 100644 --- a/src/loader/halo2/loader.rs +++ b/src/loader/halo2/loader.rs @@ -11,9 +11,7 @@ use crate::{ use halo2_proofs::circuit; use std::{ cell::{Ref, RefCell, RefMut}, - collections::btree_map::{BTreeMap, Entry}, fmt::{self, Debug}, - iter, marker::PhantomData, ops::{Add, AddAssign, Deref, Mul, MulAssign, Neg, Sub, SubAssign}, rc::Rc, @@ -25,7 +23,6 @@ pub struct Halo2Loader<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { ctx: RefCell, num_scalar: RefCell, num_ec_point: RefCell, - const_ec_point: RefCell>>, _marker: PhantomData, #[cfg(test)] row_meterings: RefCell>, @@ -38,7 +35,6 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc ctx: RefCell::new(ctx), num_scalar: RefCell::default(), num_ec_point: RefCell::default(), - const_ec_point: RefCell::default(), #[cfg(test)] row_meterings: RefCell::default(), _marker: PhantomData, @@ -61,16 +57,16 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc self.ctx.borrow() } - pub(crate) fn ctx_mut(&self) -> RefMut<'_, EccChip::Context> { + pub fn ctx_mut(&self) -> RefMut<'_, EccChip::Context> { self.ctx.borrow_mut() } - pub fn assign_const_scalar(self: &Rc, constant: C::Scalar) -> Scalar<'a, C, EccChip> { + fn assign_const_scalar(self: &Rc, constant: C::Scalar) -> Scalar<'a, C, EccChip> { let assigned = self .scalar_chip() .assign_constant(&mut self.ctx_mut(), constant) .unwrap(); - self.scalar(Value::Assigned(assigned)) + self.scalar_from_assigned(assigned) } pub fn assign_scalar( @@ -81,10 +77,17 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc .scalar_chip() .assign_integer(&mut self.ctx_mut(), scalar) .unwrap(); + self.scalar_from_assigned(assigned) + } + + pub fn scalar_from_assigned( + self: &Rc, + assigned: EccChip::AssignedScalar, + ) -> Scalar<'a, C, EccChip> { self.scalar(Value::Assigned(assigned)) } - pub(crate) fn scalar( + fn scalar( self: &Rc, value: Value, ) -> Scalar<'a, C, EccChip> { @@ -97,23 +100,12 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc } } - pub fn assign_const_ec_point(self: &Rc, constant: C) -> EcPoint<'a, C, EccChip> { - let coordinates = constant.coordinates().unwrap(); - match self - .const_ec_point - .borrow_mut() - .entry((*coordinates.x(), *coordinates.y())) - { - Entry::Occupied(entry) => entry.get().clone(), - Entry::Vacant(entry) => { - let assigned = self - .ecc_chip() - .assign_point(&mut self.ctx_mut(), circuit::Value::known(constant)) - .unwrap(); - let ec_point = self.ec_point(assigned); - entry.insert(ec_point).clone() - } - } + fn assign_const_ec_point(self: &Rc, constant: C) -> EcPoint<'a, C, EccChip> { + self.ec_point_from_assigned( + self.ecc_chip() + .assign_constant(&mut self.ctx_mut(), constant) + .unwrap(), + ) } pub fn assign_ec_point( @@ -124,16 +116,26 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc .ecc_chip() .assign_point(&mut self.ctx_mut(), ec_point) .unwrap(); - self.ec_point(assigned) + self.ec_point_from_assigned(assigned) + } + + pub fn ec_point_from_assigned( + self: &Rc, + assigned: EccChip::AssignedEcPoint, + ) -> EcPoint<'a, C, EccChip> { + self.ec_point(Value::Assigned(assigned)) } - fn ec_point(self: &Rc, assigned: EccChip::AssignedEcPoint) -> EcPoint<'a, C, EccChip> { + fn ec_point( + self: &Rc, + value: Value, + ) -> EcPoint<'a, C, EccChip> { let index = *self.num_ec_point.borrow(); *self.num_ec_point.borrow_mut() += 1; EcPoint { loader: self.clone(), index, - assigned, + value, } } @@ -305,7 +307,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Scalar<'a, C, EccChip> &self.loader } - pub(crate) fn assigned(&self) -> EccChip::AssignedScalar { + pub fn assigned(&self) -> EccChip::AssignedScalar { match &self.value { Value::Constant(constant) => self.loader.assign_const_scalar(*constant).assigned(), Value::Assigned(assigned) => assigned.clone(), @@ -451,12 +453,15 @@ impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> MulAssign<&'b Self pub struct EcPoint<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { loader: Rc>, index: usize, - assigned: EccChip::AssignedEcPoint, + value: Value, } impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPoint<'a, C, EccChip> { pub fn assigned(&self) -> EccChip::AssignedEcPoint { - self.assigned.clone() + match &self.value { + Value::Constant(constant) => self.loader.assign_const_ec_point(*constant).assigned(), + Value::Assigned(assigned) => assigned.clone(), + } } } @@ -474,65 +479,13 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> LoadedEcPoint fn loader(&self) -> &Self::Loader { &self.loader } - - fn multi_scalar_multiplication( - pairs: impl IntoIterator, Self)>, - ) -> Self { - let pairs = pairs.into_iter().collect_vec(); - let loader = &pairs[0].0.loader; - - let (non_scaled, scaled) = pairs.iter().fold( - (Vec::new(), Vec::new()), - |(mut non_scaled, mut scaled), (scalar, ec_point)| { - if matches!(scalar.value, Value::Constant(constant) if constant == C::Scalar::one()) - { - non_scaled.push(ec_point.assigned()); - } else { - scaled.push((ec_point.assigned(), scalar.assigned())) - } - (non_scaled, scaled) - }, - ); - - let output = iter::empty() - .chain(if scaled.is_empty() { - None - } else { - Some( - loader - .ecc_chip - .borrow_mut() - .multi_scalar_multiplication(&mut loader.ctx_mut(), scaled) - .unwrap(), - ) - }) - .chain(non_scaled) - .reduce(|acc, ec_point| { - EccInstructions::add( - loader.ecc_chip().deref(), - &mut loader.ctx_mut(), - &acc, - &ec_point, - ) - .unwrap() - }) - .map(|output| { - loader - .ecc_chip() - .normalize(&mut loader.ctx_mut(), &output) - .unwrap() - }) - .unwrap(); - - loader.ec_point(output) - } } impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Debug for EcPoint<'a, C, EccChip> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("EcPoint") .field("index", &self.index) - .field("assigned", &self.assigned) + .field("value", &self.value) .finish() } } @@ -596,7 +549,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPointLoader type LoadedEcPoint = EcPoint<'a, C, EccChip>; fn ec_point_load_const(&self, ec_point: &C) -> EcPoint<'a, C, EccChip> { - self.assign_const_ec_point(*ec_point) + self.ec_point(Value::Constant(*ec_point)) } fn ec_point_assert_eq( @@ -605,9 +558,100 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPointLoader lhs: &EcPoint<'a, C, EccChip>, rhs: &EcPoint<'a, C, EccChip>, ) -> Result<(), crate::Error> { - self.ecc_chip() - .assert_equal(&mut self.ctx_mut(), &lhs.assigned(), &rhs.assigned()) - .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())) + match (&lhs.value, &rhs.value) { + (Value::Constant(lhs), Value::Constant(rhs)) => { + assert_eq!(lhs, rhs); + Ok(()) + } + (Value::Constant(constant), Value::Assigned(assigned)) + | (Value::Assigned(assigned), Value::Constant(constant)) => { + let constant = self.assign_const_ec_point(*constant).assigned(); + self.ecc_chip() + .assert_equal(&mut self.ctx_mut(), assigned, &constant) + .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())) + } + (Value::Assigned(lhs), Value::Assigned(rhs)) => self + .ecc_chip() + .assert_equal(&mut self.ctx_mut(), lhs, rhs) + .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())), + } + } + + fn multi_scalar_multiplication( + pairs: &[( + >::LoadedScalar, + EcPoint<'a, C, EccChip>, + )], + ) -> EcPoint<'a, C, EccChip> { + let loader = &pairs[0].0.loader; + + let (constant, fixed_base, variable_base_non_scaled, variable_base_scaled) = + pairs.iter().fold( + (C::identity(), Vec::new(), Vec::new(), Vec::new()), + |( + mut constant, + mut fixed_base, + mut variable_base_non_scaled, + mut variable_base_scaled, + ), + (scalar, ec_point)| { + match (&ec_point.value, &scalar.value) { + (Value::Constant(ec_point), Value::Constant(scalar)) => { + constant = (*ec_point * scalar + constant).into() + } + (Value::Constant(ec_point), Value::Assigned(scalar)) => { + fixed_base.push((scalar.clone(), *ec_point)) + } + (Value::Assigned(ec_point), Value::Constant(scalar)) + if scalar.eq(&C::Scalar::one()) => + { + variable_base_non_scaled.push(ec_point.clone()); + } + (Value::Assigned(ec_point), _) => { + variable_base_scaled.push((scalar.assigned(), ec_point.clone())) + } + }; + ( + constant, + fixed_base, + variable_base_non_scaled, + variable_base_scaled, + ) + }, + ); + + let fixed_base_msm = (!fixed_base.is_empty()).then(|| { + loader + .ecc_chip + .borrow_mut() + .fixed_base_msm(&mut loader.ctx_mut(), &fixed_base) + .unwrap() + }); + let variable_base_msm = (!variable_base_scaled.is_empty()).then(|| { + loader + .ecc_chip + .borrow_mut() + .variable_base_msm(&mut loader.ctx_mut(), &variable_base_scaled) + .unwrap() + }); + let output = loader + .ecc_chip() + .sum_with_const( + &mut loader.ctx_mut(), + &variable_base_non_scaled + .into_iter() + .chain(fixed_base_msm) + .chain(variable_base_msm) + .collect_vec(), + constant, + ) + .unwrap(); + let normalized = loader + .ecc_chip() + .normalize(&mut loader.ctx_mut(), &output) + .unwrap(); + + loader.ec_point_from_assigned(normalized) } } diff --git a/src/loader/halo2/shim.rs b/src/loader/halo2/shim.rs index 67f06cc9..97921d24 100644 --- a/src/loader/halo2/shim.rs +++ b/src/loader/halo2/shim.rs @@ -98,17 +98,23 @@ pub trait EccInstructions<'a, C: CurveAffine>: Clone + Debug { point: Value, ) -> Result; - fn add( + fn sum_with_const( &self, ctx: &mut Self::Context, - p0: &Self::AssignedEcPoint, - p1: &Self::AssignedEcPoint, + values: &[Self::AssignedEcPoint], + constant: C, ) -> Result; - fn multi_scalar_multiplication( + fn fixed_base_msm( &mut self, ctx: &mut Self::Context, - pairs: Vec<(Self::AssignedEcPoint, Self::AssignedScalar)>, + pairs: &[(Self::AssignedScalar, C)], + ) -> Result; + + fn variable_base_msm( + &mut self, + ctx: &mut Self::Context, + pairs: &[(Self::AssignedScalar, Self::AssignedEcPoint)], ) -> Result; fn normalize( @@ -146,6 +152,7 @@ mod halo2_wrong { AssignedPoint, BaseFieldEccChip, }; use rand::rngs::OsRng; + use std::iter; impl<'a, F: FieldExt> Context for RegionCtx<'a, F> { fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error> { @@ -348,21 +355,51 @@ mod halo2_wrong { self.assign_point(ctx, point) } - fn add( + fn sum_with_const( &self, ctx: &mut Self::Context, - p0: &Self::AssignedEcPoint, - p1: &Self::AssignedEcPoint, + values: &[Self::AssignedEcPoint], + constant: C, + ) -> Result { + if values.is_empty() { + return self.assign_constant(ctx, constant); + } + + iter::empty() + .chain( + (!bool::from(constant.is_identity())) + .then(|| self.assign_constant(ctx, constant)), + ) + .chain(values.iter().cloned().map(Ok)) + .reduce(|acc, ec_point| self.add(ctx, &acc?, &ec_point?)) + .unwrap() + } + + fn fixed_base_msm( + &mut self, + ctx: &mut Self::Context, + pairs: &[(Self::AssignedScalar, C)], ) -> Result { - self.add(ctx, p0, p1) + // FIXME: Implement fixed base MSM in halo2_wrong + let pairs = pairs + .iter() + .map(|(scalar, base)| { + Ok::<_, Error>((scalar.clone(), self.assign_constant(ctx, *base)?)) + }) + .collect::, _>>()?; + self.variable_base_msm(ctx, &pairs) } - fn multi_scalar_multiplication( + fn variable_base_msm( &mut self, ctx: &mut Self::Context, - pairs: Vec<(Self::AssignedEcPoint, Self::AssignedScalar)>, + pairs: &[(Self::AssignedScalar, Self::AssignedEcPoint)], ) -> Result { const WINDOW_SIZE: usize = 3; + let pairs = pairs + .iter() + .map(|(scalar, base)| (base.clone(), scalar.clone())) + .collect_vec(); match self.mul_batch_1d_horizontal(ctx, pairs.clone(), WINDOW_SIZE) { Err(_) => { if self.assign_aux(ctx, WINDOW_SIZE, pairs.len()).is_err() { diff --git a/src/loader/native.rs b/src/loader/native.rs index 6bf9f7c4..1451ff76 100644 --- a/src/loader/native.rs +++ b/src/loader/native.rs @@ -19,15 +19,6 @@ impl LoadedEcPoint for C { fn loader(&self) -> &NativeLoader { &LOADER } - - fn multi_scalar_multiplication(pairs: impl IntoIterator) -> Self { - pairs - .into_iter() - .map(|(scalar, base)| base * scalar) - .reduce(|acc, value| acc + value) - .unwrap() - .to_affine() - } } impl FieldOps for F { @@ -61,6 +52,17 @@ impl EcPointLoader for NativeLoader { .then_some(()) .ok_or_else(|| Error::AssertionFailure(annotation.to_string())) } + + fn multi_scalar_multiplication( + pairs: &[(>::LoadedScalar, C)], + ) -> C { + pairs + .iter() + .map(|(scalar, base)| *base * scalar) + .reduce(|acc, value| acc + value) + .unwrap() + .to_affine() + } } impl ScalarLoader for NativeLoader { diff --git a/src/pcs/kzg/multiopen/bdfg21.rs b/src/pcs/kzg/multiopen/bdfg21.rs index 287700d7..7f70ca61 100644 --- a/src/pcs/kzg/multiopen/bdfg21.rs +++ b/src/pcs/kzg/multiopen/bdfg21.rs @@ -203,8 +203,8 @@ fn query_set_coeffs>( }) .collect_vec(); - T::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); - T::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); + T::Loader::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); + T::Loader::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); coeffs.iter_mut().for_each(QuerySetCoeff::evaluate); coeffs diff --git a/src/system/halo2.rs b/src/system/halo2.rs index bf3e4091..3d7cf140 100644 --- a/src/system/halo2.rs +++ b/src/system/halo2.rs @@ -70,8 +70,11 @@ impl Config { self } - pub fn with_accumulator_indices(mut self, accumulator_indices: Vec<(usize, usize)>) -> Self { - self.accumulator_indices = Some(accumulator_indices); + pub fn with_accumulator_indices( + mut self, + accumulator_indices: Option>, + ) -> Self { + self.accumulator_indices = accumulator_indices; self } } @@ -202,7 +205,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { degree - 1 }; - let num_phase = *cs.advice_column_phase().iter().max().unwrap() as usize + 1; + let num_phase = *cs.advice_column_phase().iter().max().unwrap_or(&0) as usize + 1; let remapping = |phase: Vec| { let num = phase.iter().fold(vec![0; num_phase], |mut num, phase| { num[*phase as usize] += 1; diff --git a/src/system/halo2/test/kzg.rs b/src/system/halo2/test/kzg.rs index 0b071175..3b7e1694 100644 --- a/src/system/halo2/test/kzg.rs +++ b/src/system/halo2/test/kzg.rs @@ -45,7 +45,7 @@ macro_rules! halo2_kzg_config { $crate::system::halo2::Config::kzg() .set_zk($zk) .with_num_proof($num_proof) - .with_accumulator_indices($accumulator_indices) + .with_accumulator_indices(Some($accumulator_indices)) }; } diff --git a/src/system/halo2/test/kzg/evm.rs b/src/system/halo2/test/kzg/evm.rs index 4ce850c1..4e57d369 100644 --- a/src/system/halo2/test/kzg/evm.rs +++ b/src/system/halo2/test/kzg/evm.rs @@ -37,16 +37,17 @@ macro_rules! halo2_kzg_evm_verify { let runtime_code = { let svk = $params.get_g()[0].into(); let dk = ($params.g2(), $params.s_g2()).into(); - let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); + let protocol = $protocol.loaded(&loader); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); let instances = transcript.load_instances( $instances .iter() .map(|instances| instances.len()) .collect_vec(), ); - let proof = <$plonk_verifier>::read_proof(&svk, $protocol, &instances, &mut transcript) + let proof = <$plonk_verifier>::read_proof(&svk, &protocol, &instances, &mut transcript) .unwrap(); - <$plonk_verifier>::verify(&svk, &dk, $protocol, &instances, &proof).unwrap(); + <$plonk_verifier>::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); loader.runtime_code() }; diff --git a/src/system/halo2/test/kzg/halo2.rs b/src/system/halo2/test/kzg/halo2.rs index 1bd332a0..18767c98 100644 --- a/src/system/halo2/test/kzg/halo2.rs +++ b/src/system/halo2/test/kzg/halo2.rs @@ -89,12 +89,12 @@ pub fn accumulate<'a>( let mut accumulators = snarks .iter() .flat_map(|snark| { + let protocol = snark.protocol.loaded(loader); let instances = assign_instances(&snark.instances); let mut transcript = PoseidonTranscript::, _>::new(loader, snark.proof()); - let proof = - Plonk::read_proof(svk, &snark.protocol, &instances, &mut transcript).unwrap(); - Plonk::succinct_verify(svk, &snark.protocol, &instances, &proof).unwrap() + let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript).unwrap(); + Plonk::succinct_verify(svk, &protocol, &instances, &proof).unwrap() }) .collect_vec(); diff --git a/src/system/halo2/transcript/evm.rs b/src/system/halo2/transcript/evm.rs index 77aca5cf..e461bc05 100644 --- a/src/system/halo2/transcript/evm.rs +++ b/src/system/halo2/transcript/evm.rs @@ -32,13 +32,13 @@ where C: CurveAffine, C::Scalar: PrimeField, { - pub fn new(loader: Rc) -> Self { + pub fn new(loader: &Rc) -> Self { let ptr = loader.allocate(0x20); assert_eq!(ptr, 0); let mut buf = MemoryChunk::new(ptr); buf.extend(0x20); Self { - loader, + loader: loader.clone(), stream: 0, buf, _marker: PhantomData, diff --git a/src/system/halo2/transcript/halo2.rs b/src/system/halo2/transcript/halo2.rs index 43701bd7..bb7e8e23 100644 --- a/src/system/halo2/transcript/halo2.rs +++ b/src/system/halo2/transcript/halo2.rs @@ -1,11 +1,11 @@ use crate::{ loader::{ - halo2::{self, EcPoint, EccInstructions, Halo2Loader, IntegerInstructions, Scalar}, + halo2::{EcPoint, EccInstructions, Halo2Loader, IntegerInstructions, Scalar}, native::{self, NativeLoader}, Loader, ScalarLoader, }, util::{ - arithmetic::{fe_from_big, fe_to_big, CurveAffine, FieldExt, PrimeField}, + arithmetic::{fe_to_fe, CurveAffine, FieldExt, PrimeField}, hash::Poseidon, transcript::{Transcript, TranscriptRead, TranscriptWrite}, Itertools, @@ -57,10 +57,11 @@ impl< > PoseidonTranscript>, Value, T, RATE, R_F, R_P> { pub fn new(loader: &Rc>, stream: Value) -> Self { + let buf = Poseidon::new(loader, R_F, R_P); Self { loader: loader.clone(), stream, - buf: Poseidon::new(loader.clone(), R_F, R_P), + buf, _marker: PhantomData, } } @@ -99,7 +100,7 @@ impl< .map(|encoded| { encoded .into_iter() - .map(|encoded| self.loader.scalar(halo2::loader::Value::Assigned(encoded))) + .map(|encoded| self.loader.scalar_from_assigned(encoded)) .collect_vec() }) .map_err(|_| Error::Transcript(io::ErrorKind::Other, "".to_string()))?; @@ -160,7 +161,7 @@ impl = Option::from(ec_point.coordinates().map(|coordinates| { [coordinates.x(), coordinates.y()] .into_iter() - .map(|fe| fe_from_big(fe_to_big(*fe))) + .cloned() + .map(fe_to_fe) .collect_vec() })) .ok_or_else(|| { diff --git a/src/util/hash/poseidon.rs b/src/util/hash/poseidon.rs index 878b69ce..c0fc03a6 100644 --- a/src/util/hash/poseidon.rs +++ b/src/util/hash/poseidon.rs @@ -113,7 +113,7 @@ pub struct Poseidon { } impl, const T: usize, const RATE: usize> Poseidon { - pub fn new(loader: L::Loader, r_f: usize, r_p: usize) -> Self { + pub fn new(loader: &L::Loader, r_f: usize, r_p: usize) -> Self { Self { spec: Spec::new(r_f, r_p), state: State::new( diff --git a/src/util/msm.rs b/src/util/msm.rs index a7a3d45d..a15eab9d 100644 --- a/src/util/msm.rs +++ b/src/util/msm.rs @@ -1,6 +1,6 @@ use crate::{ loader::{LoadedEcPoint, Loader}, - util::arithmetic::CurveAffine, + util::{arithmetic::CurveAffine, Itertools}, }; use std::{ default::Default, @@ -71,11 +71,11 @@ where .loader() .ec_point_load_const(&gen) }); - L::LoadedEcPoint::multi_scalar_multiplication( - iter::empty() - .chain(self.constant.map(|constant| (constant, gen.unwrap()))) - .chain(self.scalars.into_iter().zip(self.bases.into_iter())), - ) + let pairs = iter::empty() + .chain(self.constant.map(|constant| (constant, gen.unwrap()))) + .chain(self.scalars.into_iter().zip(self.bases.into_iter())) + .collect_vec(); + L::multi_scalar_multiplication(&pairs) } pub fn scale(&mut self, factor: &L::LoadedScalar) { diff --git a/src/util/protocol.rs b/src/util/protocol.rs index bae363ff..e9ecd9f0 100644 --- a/src/util/protocol.rs +++ b/src/util/protocol.rs @@ -4,6 +4,7 @@ use crate::{ arithmetic::{CurveAffine, Domain, Field, Fraction, Rotation}, Itertools, }, + Protocol, }; use num_integer::Integer; use num_traits::One; @@ -15,6 +16,37 @@ use std::{ ops::{Add, Mul, Neg, Sub}, }; +impl Protocol +where + C: CurveAffine, +{ + pub fn loaded>(&self, loader: &L) -> Protocol { + let preprocessed = self + .preprocessed + .iter() + .map(|preprocessed| loader.ec_point_load_const(preprocessed)) + .collect(); + let transcript_initial_state = self + .transcript_initial_state + .as_ref() + .map(|transcript_initial_state| loader.load_const(transcript_initial_state)); + Protocol { + domain: self.domain.clone(), + preprocessed, + num_instance: self.num_instance.clone(), + num_witness: self.num_witness.clone(), + num_challenge: self.num_challenge.clone(), + evaluations: self.evaluations.clone(), + queries: self.queries.clone(), + quotient: self.quotient.clone(), + transcript_initial_state, + instance_committing_key: self.instance_committing_key.clone(), + linearization: self.linearization.clone(), + accumulator_indices: self.accumulator_indices.clone(), + } + } +} + #[derive(Clone, Copy, Debug)] pub enum CommonPolynomial { Identity, diff --git a/src/verifier.rs b/src/verifier.rs index 07529603..0eef23d2 100644 --- a/src/verifier.rs +++ b/src/verifier.rs @@ -20,7 +20,7 @@ where fn read_proof( svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], transcript: &mut T, ) -> Result @@ -29,7 +29,7 @@ where fn succinct_verify( svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], proof: &Self::Proof, ) -> Result, Error>; @@ -37,7 +37,7 @@ where fn verify( svk: &MOS::SuccinctVerifyingKey, dk: &MOS::DecidingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], proof: &Self::Proof, ) -> Result diff --git a/src/verifier/plonk.rs b/src/verifier/plonk.rs index 9e08e585..c3276af4 100644 --- a/src/verifier/plonk.rs +++ b/src/verifier/plonk.rs @@ -29,7 +29,7 @@ where fn read_proof( svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], transcript: &mut T, ) -> Result @@ -41,7 +41,7 @@ where fn succinct_verify( svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], proof: &Self::Proof, ) -> Result, Error> { @@ -52,7 +52,7 @@ where &proof.z, ); - L::LoadedScalar::batch_invert(common_poly_eval.denoms()); + L::batch_invert(common_poly_eval.denoms()); common_poly_eval.evaluate(); common_poly_eval @@ -96,9 +96,9 @@ where L: Loader, MOS: MultiOpenScheme, { - fn read( + pub fn read( svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], transcript: &mut T, ) -> Result @@ -106,9 +106,8 @@ where T: TranscriptRead, AE: AccumulatorEncoding, { - let loader = transcript.loader(); if let Some(transcript_initial_state) = &protocol.transcript_initial_state { - transcript.common_scalar(&loader.load_const(transcript_initial_state))?; + transcript.common_scalar(transcript_initial_state)?; } if protocol.num_instance @@ -211,7 +210,7 @@ where }) } - fn empty_queries(protocol: &Protocol) -> Vec> { + pub fn empty_queries(protocol: &Protocol) -> Vec> { protocol .queries .iter() @@ -227,7 +226,7 @@ where fn queries( &self, - protocol: &Protocol, + protocol: &Protocol, mut evaluations: HashMap, ) -> Vec> { Self::empty_queries(protocol) @@ -244,7 +243,7 @@ where fn commitments( &self, - protocol: &Protocol, + protocol: &Protocol, common_poly_eval: &CommonPolynomialEvaluation, evaluations: &mut HashMap, ) -> Result>, Error> { @@ -254,7 +253,7 @@ where protocol .preprocessed .iter() - .map(|value| Msm::base(loader.ec_point_load_const(value))), + .map(|value| Msm::base(value.clone())), ) .chain( self.committed_instances @@ -357,7 +356,7 @@ where fn evaluations( &self, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], common_poly_eval: &CommonPolynomialEvaluation, ) -> Result, Error> { @@ -424,9 +423,13 @@ where } } -fn langranges(protocol: &Protocol, instances: &[Vec]) -> impl IntoIterator +fn langranges( + protocol: &Protocol, + instances: &[Vec], +) -> impl IntoIterator where C: CurveAffine, + L: Loader, { let instance_eval_lagrange = protocol.instance_committing_key.is_none().then(|| { let queries = { From 70d6b7340c998335e461130d766a6b4a1e3e93d6 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Fri, 28 Oct 2022 10:58:34 -0700 Subject: [PATCH 10/28] feat: implement `IntegerInstructions` and `EccInstructions` traits for `halo2-lib` --- Cargo.toml | 5 +- src/loader/halo2/shim.rs | 222 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 226 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index de037c4c..b6bec864 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,8 @@ ethereum_types = { package = "ethereum-types", version = "0.13.1", default-featu sha3 = { version = "0.10.1", optional = true } # loader_halo2 +halo2_base = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", package = "halo2_base", default-features = false, optional = true } +halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", package = "halo2_ecc", default-features = false, optional = true } halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", tag = "v2022_10_22", package = "ecc", optional = true } poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon", tag = "v2022_10_22", optional = true } @@ -28,6 +30,7 @@ poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon", t paste = "1.0.7" # system_halo2 +halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", package = "halo2_ecc", default-features = false } halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", tag = "v2022_10_22", package = "ecc" } # loader_evm @@ -39,7 +42,7 @@ tui = { version = "0.16.0", default-features = false, features = ["crossterm"] } default = ["loader_evm", "loader_halo2", "system_halo2"] loader_evm = ["dep:ethereum_types", "dep:sha3"] -loader_halo2 = ["dep:halo2_proofs", "dep:halo2_wrong_ecc", "dep:poseidon"] +loader_halo2 = ["dep:halo2_proofs", "dep:halo2_base", "halo2_ecc", "dep:halo2_wrong_ecc", "dep:poseidon"] system_halo2 = ["dep:halo2_proofs"] diff --git a/src/loader/halo2/shim.rs b/src/loader/halo2/shim.rs index 67f06cc9..0af1b6cd 100644 --- a/src/loader/halo2/shim.rs +++ b/src/loader/halo2/shim.rs @@ -125,6 +125,228 @@ pub trait EccInstructions<'a, C: CurveAffine>: Clone + Debug { ) -> Result<(), Error>; } +mod halo2_lib { + use crate::{ + loader::halo2::{Context, EccInstructions, IntegerInstructions}, + util::arithmetic::{CurveAffine, Field, FieldExt, PrimeField}, + }; + use halo2_base::{ + self, + gates::{flex_gate::FlexGateConfig, GateInstructions, RangeInstructions}, + AssignedValue, + QuantumCell::{Constant, Existing}, + }; + use halo2_ecc::{ + bigint::CRTInteger, + ecc::{fixed::FixedEccPoint, EccChip, EccPoint}, + fields::{fp::FpConfig, FieldChip}, + }; + use halo2_proofs::{ + circuit::{Cell, Value}, + plonk::Error, + }; + + pub type BaseFieldChip = FpConfig<::ScalarExt, ::Base>; + pub type AssignedInteger = CRTInteger<::ScalarExt>; + pub type AssignedEcPoint = EccPoint<::ScalarExt, AssignedInteger>; + + impl<'a, F: FieldExt> Context for halo2_base::Context<'a, F> { + fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error> { + self.region.constrain_equal(lhs, rhs) + } + + fn offset(&self) -> usize { + dbg!("using context offset"); + *self.advice_rows.values().flatten().max().unwrap() + } + } + + impl<'a, F: FieldExt> IntegerInstructions<'a, F> for FlexGateConfig { + type Context = halo2_base::Context<'a, F>; + type Integer = F; + type AssignedInteger = AssignedValue; + + fn integer(&self, scalar: F) -> Self::Integer { + scalar + } + + fn assign_integer( + &self, + ctx: &mut Self::Context, + integer: Value, + ) -> Result { + Ok(self.assign_witnesses(ctx, vec![integer])?.pop().unwrap()) + } + + fn assign_constant( + &self, + ctx: &mut Self::Context, + integer: F, + ) -> Result { + Ok(self + .assign_region(ctx, vec![Constant(integer)], vec![], None)? + .pop() + .unwrap()) + } + + fn sum_with_coeff_and_const( + &self, + ctx: &mut Self::Context, + values: &[(F, Self::AssignedInteger)], + constant: F, + ) -> Result { + let mut a = Vec::with_capacity(values.len() + 1); + let mut b = Vec::with_capacity(values.len() + 1); + if constant != F::zero() { + a.push(Constant(F::one())); + b.push(Constant(constant)); + } + a.extend(values.iter().map(|(_, a)| Existing(a))); + b.extend(values.iter().map(|(c, _)| Constant(*c))); + let (_, _, sum) = self.inner_product(ctx, &a, &b)?; + Ok(sum) + } + + fn sum_products_with_coeff_and_const( + &self, + ctx: &mut Self::Context, + values: &[(F, Self::AssignedInteger, Self::AssignedInteger)], + constant: F, + ) -> Result { + match values.len() { + 0 => self.assign_constant(ctx, constant), + _ => { + let mut prods = Vec::with_capacity(values.len()); + for (c, a, b) in values.into_iter() { + let a = Existing(&a); + let b = Existing(&b); + prods.push((*c, a, b)); + } + self.sum_products_with_coeff_and_var(ctx, &prods, &Constant(constant)) + } + } + } + + fn sub( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + b: &Self::AssignedInteger, + ) -> Result { + GateInstructions::sub(self, ctx, &Existing(a), &Existing(b)) + } + + fn neg( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + ) -> Result { + GateInstructions::neg(self, ctx, &Existing(a)) + } + + fn invert( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + ) -> Result { + // make sure scalar != 0 + let is_zero = self.is_zero(ctx, a)?; + self.assert_is_const(ctx, &is_zero, F::zero()); + GateInstructions::div_unsafe(self, ctx, &Constant(F::one()), &Existing(a)) + } + + fn assert_equal( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedInteger, + b: &Self::AssignedInteger, + ) -> Result<(), Error> { + ctx.region.constrain_equal(a.cell(), b.cell()) + } + } + + impl<'a, C: CurveAffine> EccInstructions<'a, C> for EccChip<'a, C::Scalar, BaseFieldChip> { + type Context = halo2_base::Context<'a, C::Scalar>; + type ScalarChip = FlexGateConfig; + type AssignedEcPoint = AssignedEcPoint; + type Scalar = C::Scalar; + type AssignedScalar = AssignedValue; + + fn scalar_chip(&self) -> &Self::ScalarChip { + self.field_chip.range().gate() + } + + fn assign_constant( + &self, + ctx: &mut Self::Context, + point: C, + ) -> Result { + let fixed = FixedEccPoint::::from_g1( + &point, + self.field_chip.num_limbs, + self.field_chip.limb_bits, + ); + FixedEccPoint::assign(fixed, self.field_chip, ctx) + } + + fn assign_point( + &self, + ctx: &mut Self::Context, + point: Value, + ) -> Result { + let assigned = self.assign_point(ctx, point)?; + let is_on_curve_or_infinity = self.is_on_curve_or_infinity::(ctx, &assigned)?; + self.field_chip.range.gate.assert_is_const( + ctx, + &is_on_curve_or_infinity, + C::Scalar::one(), + ); + Ok(assigned) + } + + fn add( + &self, + ctx: &mut Self::Context, + p0: &Self::AssignedEcPoint, + p1: &Self::AssignedEcPoint, + ) -> Result { + self.add_unequal(ctx, p0, p1, true) + } + + fn multi_scalar_multiplication( + &mut self, + ctx: &mut Self::Context, + pairs: Vec<(Self::AssignedEcPoint, Self::AssignedScalar)>, + ) -> Result { + let (points, scalars): (Vec<_>, Vec<_>) = pairs.into_iter().unzip(); + self.multi_scalar_mult::( + ctx, + &points, + &scalars.into_iter().map(|scalar| vec![scalar]).collect(), + ::NUM_BITS as usize, + 4, // empirically clump factor of 4 seems to be best + ) + } + + fn normalize( + &self, + _: &mut Self::Context, + _: &Self::AssignedEcPoint, + ) -> Result { + unreachable!() + } + + fn assert_equal( + &self, + ctx: &mut Self::Context, + a: &Self::AssignedEcPoint, + b: &Self::AssignedEcPoint, + ) -> Result<(), Error> { + self.assert_equal(ctx, a, b) + } + } +} + mod halo2_wrong { use crate::{ loader::halo2::{Context, EccInstructions, IntegerInstructions}, From cf9d717d763cf4fae57ebbcd8e1400ad02c1618e Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Fri, 28 Oct 2022 16:55:46 -0700 Subject: [PATCH 11/28] feat: implement `EncodeNative` and working example `evm-verifier-with-accumulator` using `halo2-lib` * Does not yet use custom `fixed_base_msm` --- .github/workflows/ci.yaml | 51 ------- Cargo.toml | 2 + configs/verify_circuit.config | 1 + examples/evm-verifier-with-accumulator.rs | 167 ++++++++++++++-------- src/loader/halo2/loader.rs | 2 +- src/loader/halo2/shim.rs | 15 +- src/pcs/kzg/accumulator.rs | 159 ++++++++++++++------ src/system/halo2/transcript/halo2.rs | 29 ++-- 8 files changed, 258 insertions(+), 168 deletions(-) delete mode 100644 .github/workflows/ci.yaml create mode 100644 configs/verify_circuit.config diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml deleted file mode 100644 index a50958d0..00000000 --- a/.github/workflows/ci.yaml +++ /dev/null @@ -1,51 +0,0 @@ -name: CI - -on: - pull_request: - push: - branches: - - main - -jobs: - test: - name: Test - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - - name: Install toolchain - uses: actions-rs/toolchain@v1 - with: - profile: minimal - override: false - - - uses: Swatinem/rust-cache@v1 - with: - cache-on-failure: true - - - name: Run test - run: cargo test --all --all-features -- --nocapture - - - lint: - name: Lint - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - - name: Install toolchain - uses: actions-rs/toolchain@v1 - with: - profile: minimal - override: false - components: rustfmt, clippy - - - uses: Swatinem/rust-cache@v1 - with: - cache-on-failure: true - - - name: Run fmt - run: cargo fmt --all -- --check - - - name: Run clippy - run: cargo clippy --all --all-features --all-targets -- -D warnings diff --git a/Cargo.toml b/Cargo.toml index b6bec864..be965520 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,8 @@ num-bigint = "0.4.3" num-integer = "0.1.45" num-traits = "0.2.15" rand = "0.8" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" rand_chacha = "0.3.1" halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.3.0", package = "halo2curves" } diff --git a/configs/verify_circuit.config b/configs/verify_circuit.config new file mode 100644 index 00000000..ae8fc999 --- /dev/null +++ b/configs/verify_circuit.config @@ -0,0 +1 @@ +{"strategy":"Simple","degree":21,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/examples/evm-verifier-with-accumulator.rs b/examples/evm-verifier-with-accumulator.rs index 69def21e..8185797b 100644 --- a/examples/evm-verifier-with-accumulator.rs +++ b/examples/evm-verifier-with-accumulator.rs @@ -28,8 +28,8 @@ use plonk_verifier::{ use rand::rngs::OsRng; use std::{io::Cursor, rc::Rc}; -const LIMBS: usize = 4; -const BITS: usize = 68; +const LIMBS: usize = 3; +const BITS: usize = 88; type Pcs = Kzg; type As = KzgAs; @@ -164,12 +164,15 @@ mod application { mod aggregation { use super::{As, Plonk, BITS, LIMBS}; + use halo2_base::{AssignedValue, Context, ContextParams}; use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; + use halo2_ecc::ecc::EccChip; use halo2_proofs::{ - circuit::{Layouter, SimpleFloorPlanner, Value}, - plonk::{self, Circuit, ConstraintSystem}, + circuit::{Cell, Layouter, SimpleFloorPlanner, Value}, + plonk::{self, Circuit, Column, ConstraintSystem, Instance}, poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, }; + /* use halo2_wrong_ecc::{ integer::rns::Rns, maingate::{ @@ -178,6 +181,7 @@ mod aggregation { }, EccConfig, }; + */ use itertools::Itertools; use plonk_verifier::{ loader::{self, native::NativeLoader}, @@ -191,7 +195,7 @@ mod aggregation { Protocol, }; use rand::rngs::OsRng; - use std::{iter, rc::Rc}; + use std::{fs::File, iter, rc::Rc}; const T: usize = 5; const RATE: usize = 4; @@ -199,8 +203,10 @@ mod aggregation { const R_P: usize = 60; type Svk = KzgSuccinctVerifyingKey; - type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip; - type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; + type BaseFieldEccChip<'b> = halo2_ecc::ecc::BaseFieldEccChip<'b, G1Affine>; + type Halo2Loader<'a, 'b> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip<'b>>; + // type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip; + // type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; pub type PoseidonTranscript = system::halo2::transcript::halo2::PoseidonTranscript; @@ -259,12 +265,12 @@ mod aggregation { } } - pub fn aggregate<'a>( + pub fn aggregate<'a, 'b>( svk: &Svk, - loader: &Rc>, + loader: &Rc>, snarks: &[SnarkWitness], as_proof: Value<&'_ [u8]>, - ) -> KzgAccumulator>> { + ) -> KzgAccumulator>> { let assign_instances = |instances: &[Vec>]| { instances .iter() @@ -299,40 +305,59 @@ mod aggregation { acccumulator } + #[derive(serde::Serialize, serde::Deserialize)] + pub struct AggregationConfigParams { + pub strategy: halo2_ecc::fields::fp::FpStrategy, + pub degree: u32, + pub num_advice: usize, + pub num_lookup_advice: usize, + pub num_fixed: usize, + pub lookup_bits: usize, + pub limb_bits: usize, + pub num_limbs: usize, + } + #[derive(Clone)] pub struct AggregationConfig { - main_gate_config: MainGateConfig, - range_config: RangeConfig, + pub base_field_config: halo2_ecc::fields::fp::FpConfig, + pub instance: Column, } impl AggregationConfig { - pub fn configure( - meta: &mut ConstraintSystem, - composition_bits: Vec, - overflow_bits: Vec, - ) -> Self { - let main_gate_config = MainGate::::configure(meta); - let range_config = - RangeChip::::configure(meta, &main_gate_config, composition_bits, overflow_bits); - AggregationConfig { - main_gate_config, - range_config, - } - } + pub fn configure(meta: &mut ConstraintSystem, params: AggregationConfigParams) -> Self { + assert!( + params.limb_bits == BITS && params.num_limbs == LIMBS, + "For now we fix limb_bits = {}, otherwise change code", + BITS + ); + let base_field_config = halo2_ecc::fields::fp::FpConfig::configure( + meta, + params.strategy, + &[params.num_advice], + &[params.num_lookup_advice], + params.num_fixed, + params.lookup_bits, + params.limb_bits, + params.num_limbs, + halo2_base::utils::modulus::(), + "verifier".to_string(), + ); + + let instance = meta.instance_column(); + meta.enable_equality(instance); - pub fn main_gate(&self) -> MainGate { - MainGate::new(self.main_gate_config.clone()) + Self { + base_field_config, + instance, + } } - pub fn range_chip(&self) -> RangeChip { - RangeChip::new(self.range_config.clone()) + pub fn range(&self) -> &halo2_base::gates::range::RangeConfig { + &self.base_field_config.range } - pub fn ecc_chip(&self) -> BaseFieldEccChip { - BaseFieldEccChip::new(EccConfig::new( - self.range_config.clone(), - self.main_gate_config.clone(), - )) + pub fn ecc_chip(&self) -> halo2_ecc::ecc::BaseFieldEccChip<'_, G1Affine> { + EccChip::construct(&self.base_field_config) } } @@ -417,11 +442,14 @@ mod aggregation { } fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { - AggregationConfig::configure( - meta, - vec![BITS / LIMBS], - Rns::::construct().overflow_lengths(), + let path = + std::env::var("VERIFY_CONFIG").expect("export VERIFY_CONFIG with config path"); + let params: AggregationConfigParams = serde_json::from_reader( + File::open(path.as_str()).expect(format!("{} file should exist", path).as_str()), ) + .unwrap(); + + AggregationConfig::configure(meta, params) } fn synthesize( @@ -429,36 +457,60 @@ mod aggregation { config: Self::Config, mut layouter: impl Layouter, ) -> Result<(), plonk::Error> { - let main_gate = config.main_gate(); - let range_chip = config.range_chip(); - - range_chip.load_table(&mut layouter)?; + config.range().load_lookup_table(&mut layouter)?; - let (lhs, rhs) = layouter.assign_region( + // Need to trick layouter to skip first pass in get shape mode + let mut first_pass = true; // assume using simple floor planner + let mut assigned_instances: Option> = None; + layouter.assign_region( || "", |region| { - let ctx = RegionCtx::new(region, 0); + if first_pass { + first_pass = false; + return Ok(()); + } + let ctx = Context::new( + region, + ContextParams { + num_advice: vec![( + config.base_field_config.range.context_id.clone(), + config.base_field_config.range.gate.num_advice, + )], + }, + ); let ecc_chip = config.ecc_chip(); let loader = Halo2Loader::new(ecc_chip, ctx); let KzgAccumulator { lhs, rhs } = aggregate(&self.svk, &loader, &self.snarks, self.as_proof()); - Ok((lhs.assigned(), rhs.assigned())) + let lhs = lhs.assigned(); + let rhs = rhs.assigned(); + + config.base_field_config.finalize(&mut loader.ctx_mut())?; + + let instances: Vec<_> = lhs + .x + .truncation + .limbs + .iter() + .chain(lhs.y.truncation.limbs.iter()) + .chain(rhs.x.truncation.limbs.iter()) + .chain(rhs.y.truncation.limbs.iter()) + .map(|assigned| assigned.cell()) + .collect(); + assigned_instances = Some(instances); + Ok(()) }, )?; - for (limb, row) in iter::empty() - .chain(lhs.x().limbs()) - .chain(lhs.y().limbs()) - .chain(rhs.x().limbs()) - .chain(rhs.y().limbs()) - .zip(0..) - { - main_gate.expose_public(layouter.namespace(|| ""), limb.into(), row)?; - } - - Ok(()) + Ok({ + // TODO: use less instances by following Scroll's strategy of keeping only last bit of y coordinate + let mut layouter = layouter.namespace(|| "expose"); + for (i, cell) in assigned_instances.unwrap().into_iter().enumerate() { + layouter.constrain_instance(cell, config.instance, i)?; + } + }) } } } @@ -592,7 +644,8 @@ fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) } fn main() { - let params = gen_srs(22); + std::env::set_var("VERIFY_CONFIG", "./configs/verify_circuit.config"); + let params = gen_srs(21); let params_app = { let mut params = params.clone(); params.downsize(8); diff --git a/src/loader/halo2/loader.rs b/src/loader/halo2/loader.rs index 0d288ce7..1ffaa924 100644 --- a/src/loader/halo2/loader.rs +++ b/src/loader/halo2/loader.rs @@ -61,7 +61,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc self.ctx.borrow() } - pub(crate) fn ctx_mut(&self) -> RefMut<'_, EccChip::Context> { + pub fn ctx_mut(&self) -> RefMut<'_, EccChip::Context> { self.ctx.borrow_mut() } diff --git a/src/loader/halo2/shim.rs b/src/loader/halo2/shim.rs index 0af1b6cd..47a1d99c 100644 --- a/src/loader/halo2/shim.rs +++ b/src/loader/halo2/shim.rs @@ -138,17 +138,16 @@ mod halo2_lib { }; use halo2_ecc::{ bigint::CRTInteger, - ecc::{fixed::FixedEccPoint, EccChip, EccPoint}, - fields::{fp::FpConfig, FieldChip}, + ecc::{fixed::FixedEccPoint, BaseFieldEccChip, EccPoint}, + fields::FieldChip, }; use halo2_proofs::{ circuit::{Cell, Value}, plonk::Error, }; - pub type BaseFieldChip = FpConfig<::ScalarExt, ::Base>; - pub type AssignedInteger = CRTInteger<::ScalarExt>; - pub type AssignedEcPoint = EccPoint<::ScalarExt, AssignedInteger>; + type AssignedInteger = CRTInteger<::ScalarExt>; + type AssignedEcPoint = EccPoint<::ScalarExt, AssignedInteger>; impl<'a, F: FieldExt> Context for halo2_base::Context<'a, F> { fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error> { @@ -265,7 +264,7 @@ mod halo2_lib { } } - impl<'a, C: CurveAffine> EccInstructions<'a, C> for EccChip<'a, C::Scalar, BaseFieldChip> { + impl<'a, 'b, C: CurveAffine> EccInstructions<'a, C> for BaseFieldEccChip<'b, C> { type Context = halo2_base::Context<'a, C::Scalar>; type ScalarChip = FlexGateConfig; type AssignedEcPoint = AssignedEcPoint; @@ -331,9 +330,9 @@ mod halo2_lib { fn normalize( &self, _: &mut Self::Context, - _: &Self::AssignedEcPoint, + point: &Self::AssignedEcPoint, ) -> Result { - unreachable!() + Ok(point.clone()) } fn assert_equal( diff --git a/src/pcs/kzg/accumulator.rs b/src/pcs/kzg/accumulator.rs index 17c7bf83..034e231c 100644 --- a/src/pcs/kzg/accumulator.rs +++ b/src/pcs/kzg/accumulator.rs @@ -137,20 +137,19 @@ mod halo2 { Error, }; use halo2_proofs::circuit::Value; - use halo2_wrong_ecc::{maingate::AssignedValue, AssignedPoint}; use std::{iter, rc::Rc}; fn ec_point_from_assigned_limbs( - limbs: &[AssignedValue], + limbs: &[Value<&C::Scalar>], ) -> Value { assert_eq!(limbs.len(), 2 * LIMBS); let [x, y] = [&limbs[..LIMBS], &limbs[LIMBS..]].map(|limbs| { limbs .iter() - .map(|assigned| assigned.value()) + .map(|limb| limb.map(|x| *x)) .fold_zipped(Vec::new(), |mut acc, limb| { - acc.push(*limb); + acc.push(limb); acc }) .map(|limbs| fe_from_limbs::<_, _, LIMBS, BITS>(limbs.try_into().unwrap())) @@ -159,50 +158,124 @@ mod halo2 { x.zip(y).map(|(x, y)| C::from_xy(x, y).unwrap()) } - impl<'a, C, PCS, EccChip, const LIMBS: usize, const BITS: usize> - AccumulatorEncoding>, PCS> for LimbsEncoding - where - C: CurveAffine, - PCS: PolynomialCommitmentScheme< - C, - Rc>, - Accumulator = KzgAccumulator>>, - >, - EccChip: EccInstructions< - 'a, - C, - AssignedEcPoint = AssignedPoint<::Base, C::Scalar, LIMBS, BITS>, - AssignedScalar = AssignedValue, - >, - { - fn from_repr(limbs: Vec>) -> Result { - assert_eq!(limbs.len(), 4 * LIMBS); + /* + mod halo2_wrong { + use super::*; + use halo2_wrong_ecc::{maingate::AssignedValue, AssignedPoint}; - let loader = limbs[0].loader(); + impl<'a, C, PCS, EccChip, const LIMBS: usize, const BITS: usize> + AccumulatorEncoding>, PCS> + for LimbsEncoding + where + C: CurveAffine, + PCS: PolynomialCommitmentScheme< + C, + Rc>, + Accumulator = KzgAccumulator>>, + >, + EccChip: EccInstructions< + 'a, + C, + AssignedEcPoint = AssignedPoint<::Base, C::Scalar, LIMBS, BITS>, + AssignedScalar = AssignedValue, + >, + { + fn from_repr(limbs: Vec>) -> Result { + assert_eq!(limbs.len(), 4 * LIMBS); - let assigned_limbs = limbs.iter().map(|limb| limb.assigned()).collect_vec(); - let [lhs, rhs] = [&assigned_limbs[..2 * LIMBS], &assigned_limbs[2 * LIMBS..]].map( - |assigned_limbs| { - let ec_point = ec_point_from_assigned_limbs::<_, LIMBS, BITS>(assigned_limbs); - loader.assign_ec_point(ec_point) - }, - ); + let loader = limbs[0].loader(); - for (src, dst) in assigned_limbs.iter().zip( - iter::empty() - .chain(lhs.assigned().x().limbs()) - .chain(lhs.assigned().y().limbs()) - .chain(rhs.assigned().x().limbs()) - .chain(rhs.assigned().y().limbs()), - ) { - loader - .ctx_mut() - .constrain_equal(src.cell(), dst.as_ref().cell()) - .unwrap(); + let assigned_limbs = limbs.iter().map(|limb| limb.assigned()).collect_vec(); + let [lhs, rhs] = [&assigned_limbs[..2 * LIMBS], &assigned_limbs[2 * LIMBS..]].map( + |assigned_limbs| { + let ec_point = ec_point_from_assigned_limbs::<_, LIMBS, BITS>( + assigned_limbs + .iter() + .map(|assigned| assigned.value()) + .collect() + .as_slice(), + ); + loader.assign_ec_point(ec_point) + }, + ); + + for (src, dst) in assigned_limbs.iter().zip( + iter::empty() + .chain(lhs.assigned().x().limbs()) + .chain(lhs.assigned().y().limbs()) + .chain(rhs.assigned().x().limbs()) + .chain(rhs.assigned().y().limbs()), + ) { + loader + .ctx_mut() + .constrain_equal(src.cell(), dst.as_ref().cell()) + .unwrap(); + } + let accumulator = KzgAccumulator::new(lhs, rhs); + + Ok(accumulator) } - let accumulator = KzgAccumulator::new(lhs, rhs); + } + } + */ - Ok(accumulator) + mod halo2_lib { + use super::*; + use halo2_base::AssignedValue; + use halo2_ecc::{bigint::CRTInteger, ecc::EccPoint}; + + impl<'a, C, PCS, EccChip, const LIMBS: usize, const BITS: usize> + AccumulatorEncoding>, PCS> + for LimbsEncoding + where + C: CurveAffine, + PCS: PolynomialCommitmentScheme< + C, + Rc>, + Accumulator = KzgAccumulator>>, + >, + EccChip: EccInstructions< + 'a, + C, + AssignedEcPoint = EccPoint>, + AssignedScalar = AssignedValue, + >, + { + fn from_repr(limbs: Vec>) -> Result { + assert_eq!(limbs.len(), 4 * LIMBS); + + let loader = limbs[0].loader(); + + let assigned_limbs = limbs.iter().map(|limb| limb.assigned()).collect_vec(); + let [lhs, rhs] = [&assigned_limbs[..2 * LIMBS], &assigned_limbs[2 * LIMBS..]].map( + |assigned_limbs| { + let ec_point = ec_point_from_assigned_limbs::<_, LIMBS, BITS>( + assigned_limbs + .iter() + .map(|assigned| assigned.value()) + .collect::>() + .as_slice(), + ); + loader.assign_ec_point(ec_point) + }, + ); + + for (src, dst) in assigned_limbs.iter().zip( + iter::empty() + .chain(lhs.assigned().x.truncation.limbs) + .chain(lhs.assigned().y.truncation.limbs) + .chain(rhs.assigned().x.truncation.limbs) + .chain(rhs.assigned().y.truncation.limbs), + ) { + loader + .ctx_mut() + .constrain_equal(src.cell(), dst.cell()) + .unwrap(); + } + let accumulator = KzgAccumulator::new(lhs, rhs); + + Ok(accumulator) + } } } } diff --git a/src/system/halo2/transcript/halo2.rs b/src/system/halo2/transcript/halo2.rs index 43701bd7..0afe487c 100644 --- a/src/system/halo2/transcript/halo2.rs +++ b/src/system/halo2/transcript/halo2.rs @@ -12,10 +12,7 @@ use crate::{ }, Error, }; -use halo2_proofs::{ - circuit::{AssignedCell, Value}, - transcript::EncodedChallenge, -}; +use halo2_proofs::{circuit::Value, transcript::EncodedChallenge}; use std::{ io::{self, Read, Write}, marker::PhantomData, @@ -27,7 +24,7 @@ pub trait EncodeNative<'a, C: CurveAffine, N: FieldExt>: EccInstructions<'a, C> &self, ctx: &mut Self::Context, ec_point: &Self::AssignedEcPoint, - ) -> Result>, Error>; + ) -> Result, Error>; } pub struct PoseidonTranscript< @@ -49,7 +46,7 @@ impl< 'a, C: CurveAffine, R: Read, - EccChip: EncodeNative<'a, C, C::Scalar, AssignedScalar = AssignedCell>, + EccChip: EncodeNative<'a, C, C::Scalar>, const T: usize, const RATE: usize, const R_F: usize, @@ -70,7 +67,7 @@ impl< 'a, C: CurveAffine, R: Read, - EccChip: EncodeNative<'a, C, C::Scalar, AssignedScalar = AssignedCell>, + EccChip: EncodeNative<'a, C, C::Scalar>, const T: usize, const RATE: usize, const R_F: usize, @@ -112,7 +109,7 @@ impl< 'a, C: CurveAffine, R: Read, - EccChip: EncodeNative<'a, C, C::Scalar, AssignedScalar = AssignedCell>, + EccChip: EncodeNative<'a, C, C::Scalar>, const T: usize, const RATE: usize, const R_F: usize, @@ -416,6 +413,22 @@ impl< } } +mod halo2_lib { + use crate::system::halo2::transcript::halo2::EncodeNative; + use halo2_curves::CurveAffine; + use halo2_ecc::ecc::BaseFieldEccChip; + + impl<'a, 'b, C: CurveAffine> EncodeNative<'a, C, C::Scalar> for BaseFieldEccChip<'b, C> { + fn encode_native( + &self, + _: &mut Self::Context, + ec_point: &Self::AssignedEcPoint, + ) -> Result, crate::Error> { + Ok(vec![ec_point.x.native.clone(), ec_point.y.native.clone()]) + } + } +} + mod halo2_wrong { use crate::system::halo2::transcript::halo2::EncodeNative; use halo2_curves::CurveAffine; From fdcc55c09afafc2138fb1129b4d93e31e8af77bf Mon Sep 17 00:00:00 2001 From: han0110 Date: Fri, 28 Oct 2022 18:48:10 +0800 Subject: [PATCH 12/28] feat: add example `recursion` --- Cargo.toml | 4 + examples/recursion.rs | 873 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 877 insertions(+) create mode 100644 examples/recursion.rs diff --git a/Cargo.toml b/Cargo.toml index 803beade..de011f92 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,3 +52,7 @@ required-features = ["loader_evm", "system_halo2"] [[example]] name = "evm-verifier-with-accumulator" required-features = ["loader_halo2", "loader_evm", "system_halo2"] + +[[example]] +name = "recursion" +required-features = ["loader_halo2", "system_halo2"] diff --git a/examples/recursion.rs b/examples/recursion.rs new file mode 100644 index 00000000..5c34d663 --- /dev/null +++ b/examples/recursion.rs @@ -0,0 +1,873 @@ +#![allow(clippy::type_complexity)] + +use common::*; +use halo2_curves::{ + bn256::{Bn256, Fq, Fr, G1Affine}, + group::ff::Field, + CurveAffine, +}; +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, SimpleFloorPlanner, Value}, + dev::MockProver, + plonk::{ + self, create_proof, keygen_pk, keygen_vk, Circuit, ConstraintSystem, Error, ProvingKey, + Selector, VerifyingKey, + }, + poly::{ + commitment::ParamsProver, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::{ProverGWC, VerifierGWC}, + strategy::AccumulatorStrategy, + }, + Rotation, VerificationStrategy, + }, +}; +use itertools::Itertools; +use plonk_verifier::{ + loader::{self, native::NativeLoader, ScalarLoader}, + pcs::{ + kzg::{Gwc19, Kzg, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding}, + AccumulationScheme, AccumulationSchemeProver, + }, + system::{ + self, + halo2::{compile, Config}, + }, + util::{ + arithmetic::{fe_to_fe, fe_to_limbs}, + hash, + }, + verifier::{self, PlonkProof, PlonkVerifier}, + Protocol, +}; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha20Rng; +use std::{fs, iter, marker::PhantomData, rc::Rc}; + +const LIMBS: usize = 4; +const BITS: usize = 68; +const T: usize = 5; +const RATE: usize = 4; +const R_F: usize = 8; +const R_P: usize = 60; + +type Pcs = Kzg; +type Svk = KzgSuccinctVerifyingKey; +type As = KzgAs; +type Plonk = verifier::Plonk>; +type Poseidon = hash::Poseidon; +type PoseidonTranscript = + system::halo2::transcript::halo2::PoseidonTranscript; + +mod common { + use super::*; + use halo2_proofs::poly::commitment::Params; + use plonk_verifier::{cost::CostEstimation, util::transcript::TranscriptWrite}; + + pub struct Snark { + pub protocol: Protocol, + pub instances: Vec>, + pub proof: Vec, + } + + impl Snark { + pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { + Self { + protocol, + instances, + proof, + } + } + } + + impl From for SnarkWitness { + fn from(snark: Snark) -> Self { + Self { + protocol: snark.protocol, + instances: snark + .instances + .into_iter() + .map(|instances| instances.into_iter().map(Value::known).collect_vec()) + .collect(), + proof: Value::known(snark.proof), + } + } + } + + #[derive(Clone)] + pub struct SnarkWitness { + pub protocol: Protocol, + pub instances: Vec>>, + pub proof: Value>, + } + + impl SnarkWitness { + pub fn without_witnesses(&self) -> Self { + SnarkWitness { + protocol: self.protocol.clone(), + instances: self + .instances + .iter() + .map(|instances| vec![Value::unknown(); instances.len()]) + .collect(), + proof: Value::unknown(), + } + } + + pub fn proof(&self) -> Value<&[u8]> { + self.proof.as_ref().map(Vec::as_slice) + } + } + + pub trait CircuitExt: Circuit { + fn num_instance() -> Vec; + + fn instances(&self) -> Vec>; + + fn accumulator_indices() -> Option> { + None + } + } + + pub fn gen_srs(k: u32) -> ParamsKZG { + let path = format!("./examples/k-{}.srs", k); + match fs::File::open(path.as_str()) { + Ok(mut file) => ParamsKZG::read(&mut file).unwrap(), + Err(_) => { + let params = + ParamsKZG::::setup(k, ChaCha20Rng::from_seed(Default::default())); + params.write(&mut fs::File::create(path).unwrap()).unwrap(); + params + } + } + } + + pub fn gen_pk>(params: &ParamsKZG, circuit: &C) -> ProvingKey { + let vk = keygen_vk(params, circuit).unwrap(); + keygen_pk(params, vk, circuit).unwrap() + } + + pub fn gen_proof>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: C, + instances: Vec>, + ) -> Vec { + if params.k() > 3 { + MockProver::run(params.k(), &circuit, instances.clone()) + .unwrap() + .assert_satisfied(); + } + + let instances = instances + .iter() + .map(|instances| instances.as_slice()) + .collect_vec(); + let proof = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + create_proof::< + KZGCommitmentScheme, + ProverGWC<_>, + _, + _, + PoseidonTranscript, + _, + >( + params, + pk, + &[circuit], + &[instances.as_slice()], + OsRng, + &mut transcript, + ) + .unwrap(); + transcript.finalize() + }; + + let accept = { + let mut transcript = PoseidonTranscript::::new(proof.as_slice()); + VerificationStrategy::<_, VerifierGWC<_>>::finalize( + plonk::verify_proof::<_, VerifierGWC<_>, _, PoseidonTranscript, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript, + ) + .unwrap(), + ) + }; + assert!(accept); + + proof + } + + pub fn gen_snark>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: ConcreteCircuit, + ) -> Snark { + let protocol = compile( + params, + pk.get_vk(), + Config::kzg() + .with_num_instance(ConcreteCircuit::num_instance()) + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), + ); + + let instances = circuit.instances(); + let proof = gen_proof(params, pk, circuit, instances.clone()); + + Snark::new(protocol, instances, proof) + } + + pub fn gen_dummy_snark>( + params: &ParamsKZG, + vk: Option<&VerifyingKey>, + ) -> Snark { + struct CsProxy(PhantomData<(F, C)>); + + impl> Circuit for CsProxy { + type Config = C::Config; + type FloorPlanner = C::FloorPlanner; + + fn without_witnesses(&self) -> Self { + CsProxy(PhantomData) + } + + fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { + C::configure(meta) + } + + fn synthesize(&self, _: Self::Config, _: impl Layouter) -> Result<(), plonk::Error> { + Ok(()) + } + } + + let dummy_vk = vk + .is_none() + .then(|| keygen_vk(params, &CsProxy::(PhantomData)).unwrap()); + let protocol = compile( + params, + vk.or(dummy_vk.as_ref()).unwrap(), + Config::kzg() + .with_num_instance(ConcreteCircuit::num_instance()) + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), + ); + let instances = ConcreteCircuit::num_instance() + .into_iter() + .map(|n| iter::repeat_with(|| Fr::random(OsRng)).take(n).collect()) + .collect(); + let proof = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + for _ in 0..protocol + .num_witness + .iter() + .chain(Some(&protocol.quotient.num_chunk())) + .sum::() + { + transcript.write_ec_point(G1Affine::random(OsRng)).unwrap(); + } + for _ in 0..protocol.evaluations.len() { + transcript.write_scalar(Fr::random(OsRng)).unwrap(); + } + let queries = PlonkProof::::empty_queries(&protocol); + for _ in 0..Pcs::estimate_cost(&queries).num_commitment { + transcript.write_ec_point(G1Affine::random(OsRng)).unwrap(); + } + transcript.finalize() + }; + + Snark::new(protocol, instances, proof) + } +} + +mod application { + use super::*; + + #[derive(Clone, Default)] + pub struct Square(Fr); + + impl Circuit for Square { + type Config = Selector; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let q = meta.selector(); + let i = meta.instance_column(); + meta.create_gate("square", |meta| { + let q = meta.query_selector(q); + let [i, i_w] = [0, 1].map(|rotation| meta.query_instance(i, Rotation(rotation))); + Some(q * (i.clone() * i - i_w)) + }); + q + } + + fn synthesize( + &self, + q: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region(|| "", |mut region| q.enable(&mut region, 0)) + } + } + + impl CircuitExt for Square { + fn num_instance() -> Vec { + vec![2] + } + + fn instances(&self) -> Vec> { + vec![vec![self.0, self.0.square()]] + } + } + + impl recursion::StateTransition for Square { + type Input = (); + + fn new(state: Fr) -> Self { + Self(state) + } + + fn state_transition(&self, _: Self::Input) -> Fr { + self.0.square() + } + } +} + +mod recursion { + use super::*; + use halo2_wrong_ecc::{ + integer::rns::Rns, + maingate::{ + MainGate, MainGateConfig, MainGateInstructions, RangeChip, RangeConfig, + RangeInstructions, RegionCtx, + }, + EccConfig, + }; + + type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip; + type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; + + pub trait StateTransition { + type Input; + + fn new(state: Fr) -> Self; + + fn state_transition(&self, input: Self::Input) -> Fr; + } + + fn succinct_verify<'a>( + svk: &Svk, + loader: &Rc>, + snark: &SnarkWitness, + expected_preprocessed_digest: Option>, + ) -> ( + Vec>>, + Vec>>>, + ) { + let protocol = if let Some(expected_preprocessed_digest) = expected_preprocessed_digest { + let protocol = snark.protocol.loaded_preprocessed_as_witness(loader); + let inputs = protocol + .preprocessed + .iter() + .flat_map(|preprocessed| { + let assigned = preprocessed.assigned(); + [assigned.x(), assigned.y()] + .map(|coordinate| loader.scalar_from_assigned(coordinate.native().clone())) + }) + .chain(protocol.transcript_initial_state.clone()) + .collect_vec(); + let preprocessed_digest = { + let mut hasher = Poseidon::new(loader, R_F, R_P); + hasher.update(&inputs); + hasher.squeeze() + }; + let expected_preprocessed_digest = + loader.scalar_from_assigned(expected_preprocessed_digest); + loader + .assert_eq("", &preprocessed_digest, &expected_preprocessed_digest) + .unwrap(); + protocol + } else { + snark.protocol.loaded(loader) + }; + + let instances = snark + .instances + .iter() + .map(|instances| { + instances + .iter() + .map(|instance| loader.assign_scalar(*instance)) + .collect_vec() + }) + .collect_vec(); + let mut transcript = PoseidonTranscript::, _>::new(loader, snark.proof()); + let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript).unwrap(); + let accumulators = Plonk::succinct_verify(svk, &protocol, &instances, &proof).unwrap(); + + ( + instances + .into_iter() + .map(|instance| { + instance + .into_iter() + .map(|instance| instance.assigned()) + .collect() + }) + .collect(), + accumulators, + ) + } + + fn select_accumulator<'a>( + loader: &Rc>, + condition: &AssignedCell, + lhs: &KzgAccumulator>>, + rhs: &KzgAccumulator>>, + ) -> Result>>, Error> { + let [lhs, rhs]: [_; 2] = [lhs.lhs.assigned(), lhs.rhs.assigned()] + .iter() + .zip([rhs.lhs.assigned(), rhs.rhs.assigned()].iter()) + .map(|(lhs, rhs)| { + loader + .ecc_chip() + .select(&mut loader.ctx_mut(), condition, lhs, rhs) + }) + .collect::, _>>()? + .try_into() + .unwrap(); + Ok(KzgAccumulator::new( + loader.ec_point_from_assigned(lhs), + loader.ec_point_from_assigned(rhs), + )) + } + + fn accumulate<'a>( + loader: &Rc>, + accumulators: Vec>>>, + as_proof: Value<&'_ [u8]>, + ) -> KzgAccumulator>> { + let mut transcript = PoseidonTranscript::, _>::new(loader, as_proof); + let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); + As::verify(&Default::default(), &accumulators, &proof).unwrap() + } + + #[derive(Clone)] + pub struct RecursionConfig { + main_gate_config: MainGateConfig, + range_config: RangeConfig, + } + + impl RecursionConfig { + pub fn main_gate(&self) -> MainGate { + MainGate::new(self.main_gate_config.clone()) + } + + pub fn range_chip(&self) -> RangeChip { + RangeChip::new(self.range_config.clone()) + } + + pub fn ecc_chip(&self) -> BaseFieldEccChip { + BaseFieldEccChip::new(EccConfig::new( + self.range_config.clone(), + self.main_gate_config.clone(), + )) + } + } + + #[derive(Clone)] + pub struct RecursionCircuit { + svk: Svk, + default_accumulator: KzgAccumulator, + app: SnarkWitness, + previous: SnarkWitness, + round: usize, + instances: Vec, + as_proof: Value>, + } + + impl RecursionCircuit { + const PREPROCESSED_DIGEST_ROW: usize = 4 * LIMBS; + const INITIAL_STATE_ROW: usize = 4 * LIMBS + 1; + const STATE_ROW: usize = 4 * LIMBS + 2; + const ROUND_ROW: usize = 4 * LIMBS + 3; + + pub fn new( + params: &ParamsKZG, + app: Snark, + previous: Snark, + initial_state: Fr, + state: Fr, + round: usize, + ) -> Self { + let svk = params.get_g()[0].into(); + let default_accumulator = KzgAccumulator::new(params.get_g()[1], params.get_g()[0]); + + let succinct_verify = |snark: &Snark| { + let mut transcript = + PoseidonTranscript::::new(snark.proof.as_slice()); + let proof = + Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript) + .unwrap(); + Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof).unwrap() + }; + + let accumulators = iter::empty() + .chain(succinct_verify(&app)) + .chain( + (round > 0) + .then_some(succinct_verify(&previous)) + .unwrap_or_else(|| vec![default_accumulator.clone(); 2]), + ) + .collect_vec(); + + let (accumulator, as_proof) = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + let accumulator = + As::create_proof(&Default::default(), &accumulators, &mut transcript, OsRng) + .unwrap(); + (accumulator, transcript.finalize()) + }; + + let preprocessed_digest = { + let inputs = previous + .protocol + .preprocessed + .iter() + .flat_map(|preprocessed| { + let coordinates = preprocessed.coordinates().unwrap(); + [*coordinates.x(), *coordinates.y()] + }) + .map(fe_to_fe) + .chain(previous.protocol.transcript_initial_state) + .collect_vec(); + let mut poseidon = Poseidon::new(&NativeLoader, R_F, R_P); + poseidon.update(&inputs); + poseidon.squeeze() + }; + let instances = [ + accumulator.lhs.x, + accumulator.lhs.y, + accumulator.rhs.x, + accumulator.rhs.y, + ] + .into_iter() + .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) + .chain([ + preprocessed_digest, + initial_state, + state, + Fr::from(round as u64), + ]) + .collect(); + + Self { + svk, + default_accumulator, + app: app.into(), + previous: previous.into(), + round, + instances, + as_proof: Value::known(as_proof), + } + } + + fn initial_snark(params: &ParamsKZG, vk: Option<&VerifyingKey>) -> Snark { + let mut snark = gen_dummy_snark::(params, vk); + let g = params.get_g(); + snark.instances = vec![[g[1].x, g[1].y, g[0].x, g[0].y] + .into_iter() + .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) + .chain([Fr::zero(); 4]) + .collect_vec()]; + snark + } + + fn as_proof(&self) -> Value<&[u8]> { + self.as_proof.as_ref().map(Vec::as_slice) + } + + fn load_default_accumulator<'a>( + &self, + loader: &Rc>, + ) -> Result>>, Error> { + let [lhs, rhs] = + [self.default_accumulator.lhs, self.default_accumulator.rhs].map(|default| { + loader + .ecc_chip() + .assign_constant(&mut loader.ctx_mut(), default) + .unwrap() + }); + Ok(KzgAccumulator::new( + loader.ec_point_from_assigned(lhs), + loader.ec_point_from_assigned(rhs), + )) + } + } + + impl Circuit for RecursionCircuit { + type Config = RecursionConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self { + svk: self.svk, + default_accumulator: self.default_accumulator.clone(), + app: self.app.without_witnesses(), + previous: self.previous.without_witnesses(), + round: self.round, + instances: self.instances.clone(), + as_proof: Value::unknown(), + } + } + + fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { + let main_gate_config = MainGate::::configure(meta); + let range_config = RangeChip::::configure( + meta, + &main_gate_config, + vec![BITS / LIMBS], + Rns::::construct().overflow_lengths(), + ); + RecursionConfig { + main_gate_config, + range_config, + } + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), plonk::Error> { + let main_gate = config.main_gate(); + let range_chip = config.range_chip(); + + range_chip.load_table(&mut layouter)?; + + let [preprocessed_digest, initial_state, state, round, first_round, not_first_round] = + layouter.assign_region( + || "", + |region| { + let mut ctx = RegionCtx::new(region, 0); + let [preprocessed_digest, initial_state, state, round] = [ + self.instances[Self::PREPROCESSED_DIGEST_ROW], + self.instances[Self::INITIAL_STATE_ROW], + self.instances[Self::STATE_ROW], + self.instances[Self::ROUND_ROW], + ] + .map(|instance| { + main_gate + .assign_value(&mut ctx, Value::known(instance)) + .unwrap() + }); + let first_round = main_gate.is_zero(&mut ctx, &round)?; + let not_first_round = main_gate.not(&mut ctx, &first_round)?; + Ok([ + preprocessed_digest, + initial_state, + state, + round, + first_round, + not_first_round, + ]) + }, + )?; + + let (lhs, rhs, app_instances, previous_instances) = layouter.assign_region( + || "", + |region| { + let loader = Halo2Loader::new(config.ecc_chip(), RegionCtx::new(region, 0)); + let (app_instances, app_accumulators) = + succinct_verify(&self.svk, &loader, &self.app, None); + let (previous_instances, previous_accumulators) = succinct_verify( + &self.svk, + &loader, + &self.previous, + Some(preprocessed_digest.clone()), + ); + + let default_accmulator = self.load_default_accumulator(&loader)?; + let previous_accumulators = previous_accumulators + .iter() + .map(|previous_accumulator| { + select_accumulator( + &loader, + &first_round, + &default_accmulator, + previous_accumulator, + ) + }) + .collect::, Error>>()?; + + let KzgAccumulator { lhs, rhs } = accumulate( + &loader, + [app_accumulators, previous_accumulators].concat(), + self.as_proof(), + ); + + Ok(( + lhs.assigned(), + rhs.assigned(), + app_instances, + previous_instances, + )) + }, + )?; + + layouter.assign_region( + || "", + |region| { + let mut ctx = RegionCtx::new(region, 0); + for (lhs, rhs) in [ + // Propagate preprocessed_digest + ( + &main_gate.mul(&mut ctx, &preprocessed_digest, ¬_first_round)?, + &previous_instances[0][Self::PREPROCESSED_DIGEST_ROW], + ), + // Propagate initial_state + ( + &main_gate.mul(&mut ctx, &initial_state, ¬_first_round)?, + &previous_instances[0][Self::INITIAL_STATE_ROW], + ), + // Verify initial_state is same as the first application snark + ( + &main_gate.mul(&mut ctx, &initial_state, &first_round)?, + &main_gate.mul(&mut ctx, &app_instances[0][0], &first_round)?, + ), + // Verify current state is same as the current application snark + (&state, &app_instances[0][1]), + // Verify previous state is same as the current application snark + ( + &main_gate.mul(&mut ctx, &app_instances[0][0], ¬_first_round)?, + &previous_instances[0][Self::STATE_ROW], + ), + // Verify round is increased by 1 when not at first round + ( + &round, + &main_gate.add( + &mut ctx, + ¬_first_round, + &previous_instances[0][Self::ROUND_ROW], + )?, + ), + ] { + main_gate.assert_equal(&mut ctx, lhs, rhs)?; + } + Ok(()) + }, + )?; + + for (row, limb) in [lhs.x(), lhs.y(), rhs.x(), rhs.y()] + .into_iter() + .flat_map(|coordinate| coordinate.limbs()) + .map_into() + .chain([preprocessed_digest, initial_state, state, round]) + .enumerate() + { + main_gate.expose_public(layouter.namespace(|| ""), limb, row)?; + } + + Ok(()) + } + } + + impl CircuitExt for RecursionCircuit { + fn num_instance() -> Vec { + // [..lhs, ..rhs, preprocessed_digest, initial_state, state, round] + vec![4 * LIMBS + 4] + } + + fn instances(&self) -> Vec> { + vec![self.instances.clone()] + } + + fn accumulator_indices() -> Option> { + Some((0..4 * LIMBS).map(|idx| (0, idx)).collect()) + } + } + + pub fn gen_recursion_pk>( + app_params: &ParamsKZG, + recursion_params: &ParamsKZG, + app_vk: &VerifyingKey, + ) -> ProvingKey { + let recursion = RecursionCircuit::new( + recursion_params, + gen_dummy_snark::(app_params, Some(app_vk)), + RecursionCircuit::initial_snark(recursion_params, None), + Fr::zero(), + Fr::zero(), + 0, + ); + gen_pk(recursion_params, &recursion) + } + + pub fn gen_recursion_snark + StateTransition>( + app_params: &ParamsKZG, + recursion_params: &ParamsKZG, + app_pk: &ProvingKey, + recursion_pk: &ProvingKey, + initial_state: Fr, + inputs: Vec, + ) -> (Fr, Snark) { + let mut state = initial_state; + let mut app = ConcreteCircuit::new(state); + let mut previous = + RecursionCircuit::initial_snark(recursion_params, Some(recursion_pk.get_vk())); + for (round, input) in inputs.into_iter().enumerate() { + state = app.state_transition(input); + let recursion = RecursionCircuit::new( + recursion_params, + gen_snark(app_params, app_pk, app), + previous, + initial_state, + state, + round, + ); + previous = gen_snark(recursion_params, recursion_pk, recursion); + app = ConcreteCircuit::new(state); + } + (state, previous) + } +} + +fn main() { + let app_params = gen_srs(3); + let recursion_params = gen_srs(22); + + let app_pk = gen_pk(&app_params, &application::Square::default()); + let recursion_pk = recursion::gen_recursion_pk::( + &app_params, + &recursion_params, + app_pk.get_vk(), + ); + + let num_round = 3; + let (final_state, snark) = recursion::gen_recursion_snark::( + &app_params, + &recursion_params, + &app_pk, + &recursion_pk, + Fr::from(2), + vec![(); num_round], + ); + assert_eq!(final_state, Fr::from(256)); + + let accept = { + let svk = recursion_params.get_g()[0].into(); + let dk = (recursion_params.g2(), recursion_params.s_g2()).into(); + let mut transcript = PoseidonTranscript::::new(snark.proof.as_slice()); + let proof = + Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript).unwrap(); + Plonk::verify(&svk, &dk, &snark.protocol, &snark.instances, &proof).unwrap() + }; + assert!(accept) +} From 3864980cf7104a192c25ac754797054b7fdebf98 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Tue, 1 Nov 2022 00:21:56 -0700 Subject: [PATCH 13/28] chore: update for compatibility with halo2-lib --- src/loader/halo2/loader.rs | 14 +++++++------- src/system/halo2/aggregation.rs | 10 +--------- src/system/halo2/test/kzg/halo2.rs | 10 +--------- 3 files changed, 9 insertions(+), 25 deletions(-) diff --git a/src/loader/halo2/loader.rs b/src/loader/halo2/loader.rs index d5bcfb26..103663f6 100644 --- a/src/loader/halo2/loader.rs +++ b/src/loader/halo2/loader.rs @@ -159,11 +159,9 @@ where let assigned = self.ecc_chip.assign_point(&mut self.ctx_mut(), ec_point).unwrap(); let is_on_curve_or_infinity = self.ecc_chip.is_on_curve_or_infinity::(&mut self.ctx_mut(), &assigned).unwrap(); - self.gate().assert_is_const( - &mut self.ctx_mut(), - &is_on_curve_or_infinity, - C::Scalar::one(), - ); + self.gate() + .assert_is_const(&mut self.ctx_mut(), &is_on_curve_or_infinity, C::Scalar::one()) + .unwrap(); self.ec_point(assigned) } @@ -335,7 +333,9 @@ where let is_zero = RangeInstructions::is_zero(self.range(), &mut self.ctx_mut(), assigned) .unwrap(); - self.gate().assert_is_const(&mut self.ctx_mut(), &is_zero, C::Scalar::zero()); + self.gate() + .assert_is_const(&mut self.ctx_mut(), &is_zero, C::Scalar::zero()) + .unwrap(); GateInstructions::div_unsafe( self.gate(), &mut self.ctx_mut(), @@ -842,7 +842,7 @@ impl<'a, 'b, C: CurveAffine> ScalarLoader for Rc Existing(assigned), })); b.extend(values.iter().map(|(c, _)| Constant(*c))); - let (_, _, sum) = self.gate().inner_product(&mut self.ctx_mut(), &a, &b).unwrap(); + let (_, _, sum) = self.gate().inner_product(&mut self.ctx_mut(), a, b).unwrap(); self.scalar(Value::Assigned(sum)) } diff --git a/src/system/halo2/aggregation.rs b/src/system/halo2/aggregation.rs index 40a154f7..46d2ace3 100644 --- a/src/system/halo2/aggregation.rs +++ b/src/system/halo2/aggregation.rs @@ -388,15 +388,7 @@ impl AggregationCircuit { first_pass = false; return Ok(()); } - let ctx = Context::new( - region, - ContextParams { - num_advice: vec![( - config.base_field_config.range.context_id.clone(), - config.base_field_config.range.gate.num_advice, - )], - }, - ); + let ctx = config.base_field_config.new_context(region); let loader = Halo2Loader::new(&config.base_field_config, ctx); let instances = aggregate( diff --git a/src/system/halo2/test/kzg/halo2.rs b/src/system/halo2/test/kzg/halo2.rs index e0cf3339..6aa4b543 100644 --- a/src/system/halo2/test/kzg/halo2.rs +++ b/src/system/halo2/test/kzg/halo2.rs @@ -303,15 +303,7 @@ impl Circuit for Accumulation { first_pass = false; return Ok(()); } - let ctx = Context::new( - region, - ContextParams { - num_advice: vec![( - config.base_field_config.range.context_id.clone(), - config.base_field_config.range.gate.num_advice, - )], - }, - ); + let ctx = config.base_field_config.new_context(region); let loader = Halo2Loader::new(&config.base_field_config, ctx); let KzgAccumulator { lhs, rhs } = From 6778bfad85880307a0a577e194f7b686b17cf558 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Tue, 1 Nov 2022 13:07:59 -0700 Subject: [PATCH 14/28] poseidon: switch to rate = 2, `t = 3` for faster proving time --- src/system/halo2/aggregation.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/system/halo2/aggregation.rs b/src/system/halo2/aggregation.rs index 46d2ace3..a3b09c15 100644 --- a/src/system/halo2/aggregation.rs +++ b/src/system/halo2/aggregation.rs @@ -53,10 +53,10 @@ use std::{ rc::Rc, }; -pub const T: usize = 5; -pub const RATE: usize = 4; +pub const T: usize = 3; +pub const RATE: usize = 2; pub const R_F: usize = 8; -pub const R_P: usize = 60; +pub const R_P: usize = 57; pub type Halo2Loader<'a, 'b> = loader::halo2::Halo2Loader<'a, 'b, G1Affine>; pub type PoseidonTranscript = From 0abbc07c2745131debe4b7458e5ec42d1dd826ae Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Thu, 3 Nov 2022 14:10:23 -0700 Subject: [PATCH 15/28] update: remove unnecessary trait `where` clause --- src/loader/halo2/shim.rs | 5 +---- src/system/halo2/transcript/halo2.rs | 5 +---- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/loader/halo2/shim.rs b/src/loader/halo2/shim.rs index 17afcaa3..3bbac5a1 100644 --- a/src/loader/halo2/shim.rs +++ b/src/loader/halo2/shim.rs @@ -271,10 +271,7 @@ mod halo2_lib { } } - impl<'a, 'b, C: CurveAffine> EccInstructions<'a, C> for BaseFieldEccChip<'b, C> - where - for<'c, 'd> &'c C::CurveExt: std::ops::Add<&'d C::CurveExt, Output = C::CurveExt>, - { + impl<'a, 'b, C: CurveAffine> EccInstructions<'a, C> for BaseFieldEccChip<'b, C> { type Context = halo2_base::Context<'a, C::Scalar>; type ScalarChip = FlexGateConfig; type AssignedEcPoint = AssignedEcPoint; diff --git a/src/system/halo2/transcript/halo2.rs b/src/system/halo2/transcript/halo2.rs index 1988f0f8..faa29207 100644 --- a/src/system/halo2/transcript/halo2.rs +++ b/src/system/halo2/transcript/halo2.rs @@ -420,10 +420,7 @@ mod halo2_lib { use halo2_curves::CurveAffine; use halo2_ecc::ecc::BaseFieldEccChip; - impl<'a, 'b, C: CurveAffine> EncodeNative<'a, C, C::Scalar> for BaseFieldEccChip<'b, C> - where - for<'c, 'd> &'c C::CurveExt: std::ops::Add<&'d C::CurveExt, Output = C::CurveExt>, - { + impl<'a, 'b, C: CurveAffine> EncodeNative<'a, C, C::Scalar> for BaseFieldEccChip<'b, C> { fn encode_native( &self, _: &mut Self::Context, From 179a31acd3a89ebaa878b464fc84fc64cd18aa37 Mon Sep 17 00:00:00 2001 From: han0110 Date: Fri, 28 Oct 2022 18:48:10 +0800 Subject: [PATCH 16/28] feat: add example `recursion` --- Cargo.toml | 4 + examples/recursion.rs | 873 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 877 insertions(+) create mode 100644 examples/recursion.rs diff --git a/Cargo.toml b/Cargo.toml index 51935591..06f289a2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,3 +61,7 @@ required-features = ["loader_evm", "system_halo2"] [[example]] name = "evm-verifier-with-accumulator" required-features = ["loader_halo2", "loader_evm", "system_halo2"] + +[[example]] +name = "recursion" +required-features = ["loader_halo2", "system_halo2"] diff --git a/examples/recursion.rs b/examples/recursion.rs new file mode 100644 index 00000000..5c34d663 --- /dev/null +++ b/examples/recursion.rs @@ -0,0 +1,873 @@ +#![allow(clippy::type_complexity)] + +use common::*; +use halo2_curves::{ + bn256::{Bn256, Fq, Fr, G1Affine}, + group::ff::Field, + CurveAffine, +}; +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, SimpleFloorPlanner, Value}, + dev::MockProver, + plonk::{ + self, create_proof, keygen_pk, keygen_vk, Circuit, ConstraintSystem, Error, ProvingKey, + Selector, VerifyingKey, + }, + poly::{ + commitment::ParamsProver, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::{ProverGWC, VerifierGWC}, + strategy::AccumulatorStrategy, + }, + Rotation, VerificationStrategy, + }, +}; +use itertools::Itertools; +use plonk_verifier::{ + loader::{self, native::NativeLoader, ScalarLoader}, + pcs::{ + kzg::{Gwc19, Kzg, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding}, + AccumulationScheme, AccumulationSchemeProver, + }, + system::{ + self, + halo2::{compile, Config}, + }, + util::{ + arithmetic::{fe_to_fe, fe_to_limbs}, + hash, + }, + verifier::{self, PlonkProof, PlonkVerifier}, + Protocol, +}; +use rand::{rngs::OsRng, SeedableRng}; +use rand_chacha::ChaCha20Rng; +use std::{fs, iter, marker::PhantomData, rc::Rc}; + +const LIMBS: usize = 4; +const BITS: usize = 68; +const T: usize = 5; +const RATE: usize = 4; +const R_F: usize = 8; +const R_P: usize = 60; + +type Pcs = Kzg; +type Svk = KzgSuccinctVerifyingKey; +type As = KzgAs; +type Plonk = verifier::Plonk>; +type Poseidon = hash::Poseidon; +type PoseidonTranscript = + system::halo2::transcript::halo2::PoseidonTranscript; + +mod common { + use super::*; + use halo2_proofs::poly::commitment::Params; + use plonk_verifier::{cost::CostEstimation, util::transcript::TranscriptWrite}; + + pub struct Snark { + pub protocol: Protocol, + pub instances: Vec>, + pub proof: Vec, + } + + impl Snark { + pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { + Self { + protocol, + instances, + proof, + } + } + } + + impl From for SnarkWitness { + fn from(snark: Snark) -> Self { + Self { + protocol: snark.protocol, + instances: snark + .instances + .into_iter() + .map(|instances| instances.into_iter().map(Value::known).collect_vec()) + .collect(), + proof: Value::known(snark.proof), + } + } + } + + #[derive(Clone)] + pub struct SnarkWitness { + pub protocol: Protocol, + pub instances: Vec>>, + pub proof: Value>, + } + + impl SnarkWitness { + pub fn without_witnesses(&self) -> Self { + SnarkWitness { + protocol: self.protocol.clone(), + instances: self + .instances + .iter() + .map(|instances| vec![Value::unknown(); instances.len()]) + .collect(), + proof: Value::unknown(), + } + } + + pub fn proof(&self) -> Value<&[u8]> { + self.proof.as_ref().map(Vec::as_slice) + } + } + + pub trait CircuitExt: Circuit { + fn num_instance() -> Vec; + + fn instances(&self) -> Vec>; + + fn accumulator_indices() -> Option> { + None + } + } + + pub fn gen_srs(k: u32) -> ParamsKZG { + let path = format!("./examples/k-{}.srs", k); + match fs::File::open(path.as_str()) { + Ok(mut file) => ParamsKZG::read(&mut file).unwrap(), + Err(_) => { + let params = + ParamsKZG::::setup(k, ChaCha20Rng::from_seed(Default::default())); + params.write(&mut fs::File::create(path).unwrap()).unwrap(); + params + } + } + } + + pub fn gen_pk>(params: &ParamsKZG, circuit: &C) -> ProvingKey { + let vk = keygen_vk(params, circuit).unwrap(); + keygen_pk(params, vk, circuit).unwrap() + } + + pub fn gen_proof>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: C, + instances: Vec>, + ) -> Vec { + if params.k() > 3 { + MockProver::run(params.k(), &circuit, instances.clone()) + .unwrap() + .assert_satisfied(); + } + + let instances = instances + .iter() + .map(|instances| instances.as_slice()) + .collect_vec(); + let proof = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + create_proof::< + KZGCommitmentScheme, + ProverGWC<_>, + _, + _, + PoseidonTranscript, + _, + >( + params, + pk, + &[circuit], + &[instances.as_slice()], + OsRng, + &mut transcript, + ) + .unwrap(); + transcript.finalize() + }; + + let accept = { + let mut transcript = PoseidonTranscript::::new(proof.as_slice()); + VerificationStrategy::<_, VerifierGWC<_>>::finalize( + plonk::verify_proof::<_, VerifierGWC<_>, _, PoseidonTranscript, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript, + ) + .unwrap(), + ) + }; + assert!(accept); + + proof + } + + pub fn gen_snark>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: ConcreteCircuit, + ) -> Snark { + let protocol = compile( + params, + pk.get_vk(), + Config::kzg() + .with_num_instance(ConcreteCircuit::num_instance()) + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), + ); + + let instances = circuit.instances(); + let proof = gen_proof(params, pk, circuit, instances.clone()); + + Snark::new(protocol, instances, proof) + } + + pub fn gen_dummy_snark>( + params: &ParamsKZG, + vk: Option<&VerifyingKey>, + ) -> Snark { + struct CsProxy(PhantomData<(F, C)>); + + impl> Circuit for CsProxy { + type Config = C::Config; + type FloorPlanner = C::FloorPlanner; + + fn without_witnesses(&self) -> Self { + CsProxy(PhantomData) + } + + fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { + C::configure(meta) + } + + fn synthesize(&self, _: Self::Config, _: impl Layouter) -> Result<(), plonk::Error> { + Ok(()) + } + } + + let dummy_vk = vk + .is_none() + .then(|| keygen_vk(params, &CsProxy::(PhantomData)).unwrap()); + let protocol = compile( + params, + vk.or(dummy_vk.as_ref()).unwrap(), + Config::kzg() + .with_num_instance(ConcreteCircuit::num_instance()) + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), + ); + let instances = ConcreteCircuit::num_instance() + .into_iter() + .map(|n| iter::repeat_with(|| Fr::random(OsRng)).take(n).collect()) + .collect(); + let proof = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + for _ in 0..protocol + .num_witness + .iter() + .chain(Some(&protocol.quotient.num_chunk())) + .sum::() + { + transcript.write_ec_point(G1Affine::random(OsRng)).unwrap(); + } + for _ in 0..protocol.evaluations.len() { + transcript.write_scalar(Fr::random(OsRng)).unwrap(); + } + let queries = PlonkProof::::empty_queries(&protocol); + for _ in 0..Pcs::estimate_cost(&queries).num_commitment { + transcript.write_ec_point(G1Affine::random(OsRng)).unwrap(); + } + transcript.finalize() + }; + + Snark::new(protocol, instances, proof) + } +} + +mod application { + use super::*; + + #[derive(Clone, Default)] + pub struct Square(Fr); + + impl Circuit for Square { + type Config = Selector; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let q = meta.selector(); + let i = meta.instance_column(); + meta.create_gate("square", |meta| { + let q = meta.query_selector(q); + let [i, i_w] = [0, 1].map(|rotation| meta.query_instance(i, Rotation(rotation))); + Some(q * (i.clone() * i - i_w)) + }); + q + } + + fn synthesize( + &self, + q: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region(|| "", |mut region| q.enable(&mut region, 0)) + } + } + + impl CircuitExt for Square { + fn num_instance() -> Vec { + vec![2] + } + + fn instances(&self) -> Vec> { + vec![vec![self.0, self.0.square()]] + } + } + + impl recursion::StateTransition for Square { + type Input = (); + + fn new(state: Fr) -> Self { + Self(state) + } + + fn state_transition(&self, _: Self::Input) -> Fr { + self.0.square() + } + } +} + +mod recursion { + use super::*; + use halo2_wrong_ecc::{ + integer::rns::Rns, + maingate::{ + MainGate, MainGateConfig, MainGateInstructions, RangeChip, RangeConfig, + RangeInstructions, RegionCtx, + }, + EccConfig, + }; + + type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip; + type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; + + pub trait StateTransition { + type Input; + + fn new(state: Fr) -> Self; + + fn state_transition(&self, input: Self::Input) -> Fr; + } + + fn succinct_verify<'a>( + svk: &Svk, + loader: &Rc>, + snark: &SnarkWitness, + expected_preprocessed_digest: Option>, + ) -> ( + Vec>>, + Vec>>>, + ) { + let protocol = if let Some(expected_preprocessed_digest) = expected_preprocessed_digest { + let protocol = snark.protocol.loaded_preprocessed_as_witness(loader); + let inputs = protocol + .preprocessed + .iter() + .flat_map(|preprocessed| { + let assigned = preprocessed.assigned(); + [assigned.x(), assigned.y()] + .map(|coordinate| loader.scalar_from_assigned(coordinate.native().clone())) + }) + .chain(protocol.transcript_initial_state.clone()) + .collect_vec(); + let preprocessed_digest = { + let mut hasher = Poseidon::new(loader, R_F, R_P); + hasher.update(&inputs); + hasher.squeeze() + }; + let expected_preprocessed_digest = + loader.scalar_from_assigned(expected_preprocessed_digest); + loader + .assert_eq("", &preprocessed_digest, &expected_preprocessed_digest) + .unwrap(); + protocol + } else { + snark.protocol.loaded(loader) + }; + + let instances = snark + .instances + .iter() + .map(|instances| { + instances + .iter() + .map(|instance| loader.assign_scalar(*instance)) + .collect_vec() + }) + .collect_vec(); + let mut transcript = PoseidonTranscript::, _>::new(loader, snark.proof()); + let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript).unwrap(); + let accumulators = Plonk::succinct_verify(svk, &protocol, &instances, &proof).unwrap(); + + ( + instances + .into_iter() + .map(|instance| { + instance + .into_iter() + .map(|instance| instance.assigned()) + .collect() + }) + .collect(), + accumulators, + ) + } + + fn select_accumulator<'a>( + loader: &Rc>, + condition: &AssignedCell, + lhs: &KzgAccumulator>>, + rhs: &KzgAccumulator>>, + ) -> Result>>, Error> { + let [lhs, rhs]: [_; 2] = [lhs.lhs.assigned(), lhs.rhs.assigned()] + .iter() + .zip([rhs.lhs.assigned(), rhs.rhs.assigned()].iter()) + .map(|(lhs, rhs)| { + loader + .ecc_chip() + .select(&mut loader.ctx_mut(), condition, lhs, rhs) + }) + .collect::, _>>()? + .try_into() + .unwrap(); + Ok(KzgAccumulator::new( + loader.ec_point_from_assigned(lhs), + loader.ec_point_from_assigned(rhs), + )) + } + + fn accumulate<'a>( + loader: &Rc>, + accumulators: Vec>>>, + as_proof: Value<&'_ [u8]>, + ) -> KzgAccumulator>> { + let mut transcript = PoseidonTranscript::, _>::new(loader, as_proof); + let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); + As::verify(&Default::default(), &accumulators, &proof).unwrap() + } + + #[derive(Clone)] + pub struct RecursionConfig { + main_gate_config: MainGateConfig, + range_config: RangeConfig, + } + + impl RecursionConfig { + pub fn main_gate(&self) -> MainGate { + MainGate::new(self.main_gate_config.clone()) + } + + pub fn range_chip(&self) -> RangeChip { + RangeChip::new(self.range_config.clone()) + } + + pub fn ecc_chip(&self) -> BaseFieldEccChip { + BaseFieldEccChip::new(EccConfig::new( + self.range_config.clone(), + self.main_gate_config.clone(), + )) + } + } + + #[derive(Clone)] + pub struct RecursionCircuit { + svk: Svk, + default_accumulator: KzgAccumulator, + app: SnarkWitness, + previous: SnarkWitness, + round: usize, + instances: Vec, + as_proof: Value>, + } + + impl RecursionCircuit { + const PREPROCESSED_DIGEST_ROW: usize = 4 * LIMBS; + const INITIAL_STATE_ROW: usize = 4 * LIMBS + 1; + const STATE_ROW: usize = 4 * LIMBS + 2; + const ROUND_ROW: usize = 4 * LIMBS + 3; + + pub fn new( + params: &ParamsKZG, + app: Snark, + previous: Snark, + initial_state: Fr, + state: Fr, + round: usize, + ) -> Self { + let svk = params.get_g()[0].into(); + let default_accumulator = KzgAccumulator::new(params.get_g()[1], params.get_g()[0]); + + let succinct_verify = |snark: &Snark| { + let mut transcript = + PoseidonTranscript::::new(snark.proof.as_slice()); + let proof = + Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript) + .unwrap(); + Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof).unwrap() + }; + + let accumulators = iter::empty() + .chain(succinct_verify(&app)) + .chain( + (round > 0) + .then_some(succinct_verify(&previous)) + .unwrap_or_else(|| vec![default_accumulator.clone(); 2]), + ) + .collect_vec(); + + let (accumulator, as_proof) = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + let accumulator = + As::create_proof(&Default::default(), &accumulators, &mut transcript, OsRng) + .unwrap(); + (accumulator, transcript.finalize()) + }; + + let preprocessed_digest = { + let inputs = previous + .protocol + .preprocessed + .iter() + .flat_map(|preprocessed| { + let coordinates = preprocessed.coordinates().unwrap(); + [*coordinates.x(), *coordinates.y()] + }) + .map(fe_to_fe) + .chain(previous.protocol.transcript_initial_state) + .collect_vec(); + let mut poseidon = Poseidon::new(&NativeLoader, R_F, R_P); + poseidon.update(&inputs); + poseidon.squeeze() + }; + let instances = [ + accumulator.lhs.x, + accumulator.lhs.y, + accumulator.rhs.x, + accumulator.rhs.y, + ] + .into_iter() + .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) + .chain([ + preprocessed_digest, + initial_state, + state, + Fr::from(round as u64), + ]) + .collect(); + + Self { + svk, + default_accumulator, + app: app.into(), + previous: previous.into(), + round, + instances, + as_proof: Value::known(as_proof), + } + } + + fn initial_snark(params: &ParamsKZG, vk: Option<&VerifyingKey>) -> Snark { + let mut snark = gen_dummy_snark::(params, vk); + let g = params.get_g(); + snark.instances = vec![[g[1].x, g[1].y, g[0].x, g[0].y] + .into_iter() + .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) + .chain([Fr::zero(); 4]) + .collect_vec()]; + snark + } + + fn as_proof(&self) -> Value<&[u8]> { + self.as_proof.as_ref().map(Vec::as_slice) + } + + fn load_default_accumulator<'a>( + &self, + loader: &Rc>, + ) -> Result>>, Error> { + let [lhs, rhs] = + [self.default_accumulator.lhs, self.default_accumulator.rhs].map(|default| { + loader + .ecc_chip() + .assign_constant(&mut loader.ctx_mut(), default) + .unwrap() + }); + Ok(KzgAccumulator::new( + loader.ec_point_from_assigned(lhs), + loader.ec_point_from_assigned(rhs), + )) + } + } + + impl Circuit for RecursionCircuit { + type Config = RecursionConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self { + svk: self.svk, + default_accumulator: self.default_accumulator.clone(), + app: self.app.without_witnesses(), + previous: self.previous.without_witnesses(), + round: self.round, + instances: self.instances.clone(), + as_proof: Value::unknown(), + } + } + + fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { + let main_gate_config = MainGate::::configure(meta); + let range_config = RangeChip::::configure( + meta, + &main_gate_config, + vec![BITS / LIMBS], + Rns::::construct().overflow_lengths(), + ); + RecursionConfig { + main_gate_config, + range_config, + } + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), plonk::Error> { + let main_gate = config.main_gate(); + let range_chip = config.range_chip(); + + range_chip.load_table(&mut layouter)?; + + let [preprocessed_digest, initial_state, state, round, first_round, not_first_round] = + layouter.assign_region( + || "", + |region| { + let mut ctx = RegionCtx::new(region, 0); + let [preprocessed_digest, initial_state, state, round] = [ + self.instances[Self::PREPROCESSED_DIGEST_ROW], + self.instances[Self::INITIAL_STATE_ROW], + self.instances[Self::STATE_ROW], + self.instances[Self::ROUND_ROW], + ] + .map(|instance| { + main_gate + .assign_value(&mut ctx, Value::known(instance)) + .unwrap() + }); + let first_round = main_gate.is_zero(&mut ctx, &round)?; + let not_first_round = main_gate.not(&mut ctx, &first_round)?; + Ok([ + preprocessed_digest, + initial_state, + state, + round, + first_round, + not_first_round, + ]) + }, + )?; + + let (lhs, rhs, app_instances, previous_instances) = layouter.assign_region( + || "", + |region| { + let loader = Halo2Loader::new(config.ecc_chip(), RegionCtx::new(region, 0)); + let (app_instances, app_accumulators) = + succinct_verify(&self.svk, &loader, &self.app, None); + let (previous_instances, previous_accumulators) = succinct_verify( + &self.svk, + &loader, + &self.previous, + Some(preprocessed_digest.clone()), + ); + + let default_accmulator = self.load_default_accumulator(&loader)?; + let previous_accumulators = previous_accumulators + .iter() + .map(|previous_accumulator| { + select_accumulator( + &loader, + &first_round, + &default_accmulator, + previous_accumulator, + ) + }) + .collect::, Error>>()?; + + let KzgAccumulator { lhs, rhs } = accumulate( + &loader, + [app_accumulators, previous_accumulators].concat(), + self.as_proof(), + ); + + Ok(( + lhs.assigned(), + rhs.assigned(), + app_instances, + previous_instances, + )) + }, + )?; + + layouter.assign_region( + || "", + |region| { + let mut ctx = RegionCtx::new(region, 0); + for (lhs, rhs) in [ + // Propagate preprocessed_digest + ( + &main_gate.mul(&mut ctx, &preprocessed_digest, ¬_first_round)?, + &previous_instances[0][Self::PREPROCESSED_DIGEST_ROW], + ), + // Propagate initial_state + ( + &main_gate.mul(&mut ctx, &initial_state, ¬_first_round)?, + &previous_instances[0][Self::INITIAL_STATE_ROW], + ), + // Verify initial_state is same as the first application snark + ( + &main_gate.mul(&mut ctx, &initial_state, &first_round)?, + &main_gate.mul(&mut ctx, &app_instances[0][0], &first_round)?, + ), + // Verify current state is same as the current application snark + (&state, &app_instances[0][1]), + // Verify previous state is same as the current application snark + ( + &main_gate.mul(&mut ctx, &app_instances[0][0], ¬_first_round)?, + &previous_instances[0][Self::STATE_ROW], + ), + // Verify round is increased by 1 when not at first round + ( + &round, + &main_gate.add( + &mut ctx, + ¬_first_round, + &previous_instances[0][Self::ROUND_ROW], + )?, + ), + ] { + main_gate.assert_equal(&mut ctx, lhs, rhs)?; + } + Ok(()) + }, + )?; + + for (row, limb) in [lhs.x(), lhs.y(), rhs.x(), rhs.y()] + .into_iter() + .flat_map(|coordinate| coordinate.limbs()) + .map_into() + .chain([preprocessed_digest, initial_state, state, round]) + .enumerate() + { + main_gate.expose_public(layouter.namespace(|| ""), limb, row)?; + } + + Ok(()) + } + } + + impl CircuitExt for RecursionCircuit { + fn num_instance() -> Vec { + // [..lhs, ..rhs, preprocessed_digest, initial_state, state, round] + vec![4 * LIMBS + 4] + } + + fn instances(&self) -> Vec> { + vec![self.instances.clone()] + } + + fn accumulator_indices() -> Option> { + Some((0..4 * LIMBS).map(|idx| (0, idx)).collect()) + } + } + + pub fn gen_recursion_pk>( + app_params: &ParamsKZG, + recursion_params: &ParamsKZG, + app_vk: &VerifyingKey, + ) -> ProvingKey { + let recursion = RecursionCircuit::new( + recursion_params, + gen_dummy_snark::(app_params, Some(app_vk)), + RecursionCircuit::initial_snark(recursion_params, None), + Fr::zero(), + Fr::zero(), + 0, + ); + gen_pk(recursion_params, &recursion) + } + + pub fn gen_recursion_snark + StateTransition>( + app_params: &ParamsKZG, + recursion_params: &ParamsKZG, + app_pk: &ProvingKey, + recursion_pk: &ProvingKey, + initial_state: Fr, + inputs: Vec, + ) -> (Fr, Snark) { + let mut state = initial_state; + let mut app = ConcreteCircuit::new(state); + let mut previous = + RecursionCircuit::initial_snark(recursion_params, Some(recursion_pk.get_vk())); + for (round, input) in inputs.into_iter().enumerate() { + state = app.state_transition(input); + let recursion = RecursionCircuit::new( + recursion_params, + gen_snark(app_params, app_pk, app), + previous, + initial_state, + state, + round, + ); + previous = gen_snark(recursion_params, recursion_pk, recursion); + app = ConcreteCircuit::new(state); + } + (state, previous) + } +} + +fn main() { + let app_params = gen_srs(3); + let recursion_params = gen_srs(22); + + let app_pk = gen_pk(&app_params, &application::Square::default()); + let recursion_pk = recursion::gen_recursion_pk::( + &app_params, + &recursion_params, + app_pk.get_vk(), + ); + + let num_round = 3; + let (final_state, snark) = recursion::gen_recursion_snark::( + &app_params, + &recursion_params, + &app_pk, + &recursion_pk, + Fr::from(2), + vec![(); num_round], + ); + assert_eq!(final_state, Fr::from(256)); + + let accept = { + let svk = recursion_params.get_g()[0].into(); + let dk = (recursion_params.g2(), recursion_params.s_g2()).into(); + let mut transcript = PoseidonTranscript::::new(snark.proof.as_slice()); + let proof = + Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript).unwrap(); + Plonk::verify(&svk, &dk, &snark.protocol, &snark.instances, &proof).unwrap() + }; + assert!(accept) +} From 25dbaf524e7a81f52e65aa6ae45a8f1e7ac509c3 Mon Sep 17 00:00:00 2001 From: Han Date: Tue, 8 Nov 2022 10:21:24 -0800 Subject: [PATCH 17/28] General refactor for further integration (#13) * feat: remove dev-dependency `foundry` and vendor necessary part of it * refactor: simplify traits and remove unused stuff * refactor: much less clone * feat: generalized `AccumulatorEncoding` for `EccInstructions` --- Cargo.toml | 15 +- examples/evm-verifier-with-accumulator.rs | 48 +- examples/evm-verifier.rs | 13 +- src/loader.rs | 13 +- src/loader/evm.rs | 4 +- src/loader/evm/loader.rs | 11 +- src/loader/evm/test.rs | 22 +- src/loader/evm/test/tui.rs | 57 +- src/loader/evm/util.rs | 4 + src/loader/evm/util/executor.rs | 868 ++++++++++++++++++++++ src/loader/halo2.rs | 8 +- src/loader/halo2/loader.rs | 221 +++--- src/loader/halo2/shim.rs | 168 +++-- src/loader/native.rs | 3 +- src/pcs.rs | 4 +- src/pcs/kzg.rs | 3 + src/pcs/kzg/accumulation.rs | 14 +- src/pcs/kzg/accumulator.rs | 132 +++- src/pcs/kzg/decider.rs | 2 +- src/pcs/kzg/multiopen/bdfg21.rs | 35 +- src/pcs/kzg/multiopen/gwc19.rs | 20 +- src/system/halo2.rs | 22 +- src/system/halo2/test/kzg.rs | 4 +- src/system/halo2/test/kzg/halo2.rs | 33 +- src/system/halo2/transcript/halo2.rs | 191 ++--- src/util/arithmetic.rs | 17 +- src/util/msm.rs | 62 +- src/util/protocol.rs | 8 +- src/verifier/plonk.rs | 36 +- 29 files changed, 1488 insertions(+), 550 deletions(-) create mode 100644 src/loader/evm/util/executor.rs diff --git a/Cargo.toml b/Cargo.toml index 803beade..e5bbb0d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,14 +10,18 @@ num-bigint = "0.4.3" num-integer = "0.1.45" num-traits = "0.2.15" rand = "0.8" +hex = "0.4" halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.3.0", package = "halo2curves" } # system_halo2 halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2", tag = "v2022_10_22", optional = true } # loader_evm -ethereum_types = { package = "ethereum-types", version = "0.13.1", default-features = false, features = ["std"], optional = true } -sha3 = { version = "0.10.1", optional = true } +ethereum_types = { package = "ethereum-types", version = "0.13", default-features = false, features = ["std"], optional = true } +sha3 = { version = "0.10", optional = true } +revm = { version = "2.1.0", optional = true } +bytes = { version = "1.2", optional = true } +rlp = { version = "0.5", default-features = false, features = ["std"], optional = true } # loader_halo2 halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", tag = "v2022_10_22", package = "ecc", optional = true } @@ -31,14 +35,13 @@ paste = "1.0.7" halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", tag = "v2022_10_22", package = "ecc" } # loader_evm -foundry_evm = { git = "https://github.com/foundry-rs/foundry", package = "foundry-evm", rev = "6b1ee60e" } -crossterm = { version = "0.22.1" } -tui = { version = "0.16.0", default-features = false, features = ["crossterm"] } +crossterm = { version = "0.25" } +tui = { version = "0.19", default-features = false, features = ["crossterm"] } [features] default = ["loader_evm", "loader_halo2", "system_halo2"] -loader_evm = ["dep:ethereum_types", "dep:sha3"] +loader_evm = ["dep:ethereum_types", "dep:sha3", "dep:revm", "dep:bytes", "dep:rlp"] loader_halo2 = ["dep:halo2_proofs", "dep:halo2_wrong_ecc", "dep:poseidon"] system_halo2 = ["dep:halo2_proofs"] diff --git a/examples/evm-verifier-with-accumulator.rs b/examples/evm-verifier-with-accumulator.rs index 66fa8ddf..10a7b269 100644 --- a/examples/evm-verifier-with-accumulator.rs +++ b/examples/evm-verifier-with-accumulator.rs @@ -1,5 +1,4 @@ use ethereum_types::Address; -use foundry_evm::executor::{fork::MultiFork, Backend, ExecutorBuilder}; use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; use halo2_proofs::{ dev::MockProver, @@ -18,7 +17,7 @@ use halo2_proofs::{ use itertools::Itertools; use plonk_verifier::{ loader::{ - evm::{encode_calldata, EvmLoader}, + evm::{encode_calldata, EvmLoader, ExecutorBuilder}, native::NativeLoader, }, pcs::kzg::{Gwc19, Kzg, KzgAs, LimbsEncoding}, @@ -167,7 +166,7 @@ mod aggregation { use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; use halo2_proofs::{ circuit::{Layouter, SimpleFloorPlanner, Value}, - plonk::{self, Circuit, ConstraintSystem}, + plonk::{self, Circuit, ConstraintSystem, Error}, poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, }; use halo2_wrong_ecc::{ @@ -182,7 +181,7 @@ mod aggregation { use plonk_verifier::{ loader::{self, native::NativeLoader}, pcs::{ - kzg::{KzgAccumulator, KzgSuccinctVerifyingKey}, + kzg::{KzgAccumulator, KzgSuccinctVerifyingKey, LimbsEncodingInstructions}, AccumulationScheme, AccumulationSchemeProver, }, system, @@ -191,7 +190,7 @@ mod aggregation { Protocol, }; use rand::rngs::OsRng; - use std::{iter, rc::Rc}; + use std::rc::Rc; const T: usize = 5; const RATE: usize = 4; @@ -434,28 +433,33 @@ mod aggregation { range_chip.load_table(&mut layouter)?; - let (lhs, rhs) = layouter.assign_region( + let accumulator_limbs = layouter.assign_region( || "", |region| { let ctx = RegionCtx::new(region, 0); let ecc_chip = config.ecc_chip(); let loader = Halo2Loader::new(ecc_chip, ctx); - let KzgAccumulator { lhs, rhs } = - aggregate(&self.svk, &loader, &self.snarks, self.as_proof()); + let accumulator = aggregate(&self.svk, &loader, &self.snarks, self.as_proof()); - Ok((lhs.assigned(), rhs.assigned())) + let accumulator_limbs = [accumulator.lhs, accumulator.rhs] + .iter() + .map(|ec_point| { + loader.ecc_chip().assign_ec_point_to_limbs( + &mut loader.ctx_mut(), + ec_point.assigned(), + ) + }) + .collect::, Error>>()? + .into_iter() + .flatten(); + + Ok(accumulator_limbs) }, )?; - for (limb, row) in iter::empty() - .chain(lhs.x().limbs()) - .chain(lhs.y().limbs()) - .chain(rhs.x().limbs()) - .chain(rhs.y().limbs()) - .zip(0..) - { - main_gate.expose_public(layouter.namespace(|| ""), limb.into(), row)?; + for (row, limb) in accumulator_limbs.enumerate() { + main_gate.expose_public(layouter.namespace(|| ""), limb, row)?; } Ok(()) @@ -574,16 +578,14 @@ fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) let success = { let mut evm = ExecutorBuilder::default() .with_gas_limit(u64::MAX.into()) - .build(Backend::new(MultiFork::new().0, None)); + .build(); let caller = Address::from_low_u64_be(0xfe); let verifier = evm - .deploy(caller, deployment_code.into(), 0.into(), None) - .unwrap() - .address; - let result = evm - .call_raw(caller, verifier, calldata.into(), 0.into()) + .deploy(caller, deployment_code.into(), 0.into()) + .address .unwrap(); + let result = evm.call_raw(caller, verifier, calldata.into(), 0.into()); dbg!(result.gas_used); diff --git a/examples/evm-verifier.rs b/examples/evm-verifier.rs index 9ed70b36..4ff5c682 100644 --- a/examples/evm-verifier.rs +++ b/examples/evm-verifier.rs @@ -1,5 +1,4 @@ use ethereum_types::Address; -use foundry_evm::executor::{fork::MultiFork, Backend, ExecutorBuilder}; use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; use halo2_proofs::{ circuit::{Layouter, SimpleFloorPlanner, Value}, @@ -21,7 +20,7 @@ use halo2_proofs::{ }; use itertools::Itertools; use plonk_verifier::{ - loader::evm::{encode_calldata, EvmLoader}, + loader::evm::{encode_calldata, EvmLoader, ExecutorBuilder}, pcs::kzg::{Gwc19, Kzg}, system::halo2::{compile, transcript::evm::EvmTranscript, Config}, verifier::{self, PlonkVerifier}, @@ -231,16 +230,14 @@ fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) let success = { let mut evm = ExecutorBuilder::default() .with_gas_limit(u64::MAX.into()) - .build(Backend::new(MultiFork::new().0, None)); + .build(); let caller = Address::from_low_u64_be(0xfe); let verifier = evm - .deploy(caller, deployment_code.into(), 0.into(), None) - .unwrap() - .address; - let result = evm - .call_raw(caller, verifier, calldata.into(), 0.into()) + .deploy(caller, deployment_code.into(), 0.into()) + .address .unwrap(); + let result = evm.call_raw(caller, verifier, calldata.into(), 0.into()); dbg!(result.gas_used); diff --git a/src/loader.rs b/src/loader.rs index 5a040f9f..297390d0 100644 --- a/src/loader.rs +++ b/src/loader.rs @@ -5,7 +5,7 @@ use crate::{ }, Error, }; -use std::{fmt::Debug, iter}; +use std::{borrow::Cow, fmt::Debug, iter, ops::Deref}; pub mod native; @@ -86,7 +86,7 @@ pub trait EcPointLoader { ) -> Result<(), Error>; fn multi_scalar_multiplication( - pairs: &[(Self::LoadedScalar, Self::LoadedEcPoint)], + pairs: &[(&Self::LoadedScalar, &Self::LoadedEcPoint)], ) -> Self::LoadedEcPoint where Self: ScalarLoader; @@ -126,17 +126,18 @@ pub trait ScalarLoader { .chain(if constant == F::zero() { None } else { - Some(loader.load_const(&constant)) + Some(Cow::Owned(loader.load_const(&constant))) }) .chain(values.iter().map(|&(coeff, value)| { if coeff == F::one() { - value.clone() + Cow::Borrowed(value) } else { - loader.load_const(&coeff) * value + Cow::Owned(loader.load_const(&coeff) * value) } })) - .reduce(|acc, term| acc + term) + .reduce(|acc, term| Cow::Owned(acc.into_owned() + term.deref())) .unwrap() + .into_owned() } fn sum_products_with_coeff_and_const( diff --git a/src/loader/evm.rs b/src/loader/evm.rs index 7a07670c..fa80b97e 100644 --- a/src/loader/evm.rs +++ b/src/loader/evm.rs @@ -6,7 +6,9 @@ mod util; mod test; pub use loader::{EcPoint, EvmLoader, Scalar}; -pub use util::{encode_calldata, estimate_gas, fe_to_u256, modulus, u256_to_fe, MemoryChunk}; +pub use util::{ + encode_calldata, estimate_gas, fe_to_u256, modulus, u256_to_fe, ExecutorBuilder, MemoryChunk, +}; pub use ethereum_types::U256; diff --git a/src/loader/evm/loader.rs b/src/loader/evm/loader.rs index 7d1dbb94..9e8b8e2a 100644 --- a/src/loader/evm/loader.rs +++ b/src/loader/evm/loader.rs @@ -37,7 +37,7 @@ impl PartialEq for Value { impl Value { fn identifier(&self) -> String { - match &self { + match self { Value::Constant(_) | Value::Memory(_) => format!("{:?}", self), Value::Negated(value) => format!("-({:?})", value), Value::Sum(lhs, rhs) => format!("({:?} + {:?})", lhs, rhs), @@ -222,13 +222,13 @@ impl EvmLoader { pub fn ec_point_from_limbs( self: &Rc, - x_limbs: [Scalar; LIMBS], - y_limbs: [Scalar; LIMBS], + x_limbs: [&Scalar; LIMBS], + y_limbs: [&Scalar; LIMBS], ) -> EcPoint { let ptr = self.allocate(0x40); for (ptr, limbs) in [(ptr, x_limbs), (ptr + 0x20, y_limbs)] { for (idx, limb) in limbs.into_iter().enumerate() { - self.push(&limb); + self.push(limb); // [..., success, acc] if idx > 0 { self.code @@ -769,10 +769,11 @@ where } fn multi_scalar_multiplication( - pairs: &[(>::LoadedScalar, EcPoint)], + pairs: &[(&>::LoadedScalar, &EcPoint)], ) -> EcPoint { pairs .iter() + .cloned() .map(|(scalar, ec_point)| match scalar.value { Value::Constant(constant) if U256::one() == constant => ec_point.clone(), _ => ec_point.loader.ec_point_scalar_mul(ec_point, scalar), diff --git a/src/loader/evm/test.rs b/src/loader/evm/test.rs index e204c1b8..f81c5f54 100644 --- a/src/loader/evm/test.rs +++ b/src/loader/evm/test.rs @@ -1,10 +1,9 @@ -use crate::{loader::evm::test::tui::Tui, util::Itertools}; -use foundry_evm::{ - executor::{backend::Backend, fork::MultiFork, ExecutorBuilder}, - revm::{AccountInfo, Bytecode}, - utils::h256_to_u256_be, - Address, +use crate::{ + loader::evm::{test::tui::Tui, util::ExecutorBuilder}, + util::Itertools, }; +use ethereum_types::{Address, U256}; +use revm::{AccountInfo, Bytecode}; use std::env::var_os; mod tui; @@ -29,23 +28,20 @@ pub fn execute(code: Vec, calldata: Vec) -> (bool, u64, Vec) { let mut evm = ExecutorBuilder::default() .with_gas_limit(u64::MAX.into()) - .set_tracing(debug) .set_debugger(debug) - .build(Backend::new(MultiFork::new().0, None)); + .build(); - evm.backend_mut().insert_account_info( + evm.db_mut().insert_account_info( callee, AccountInfo::new(0.into(), 1, Bytecode::new_raw(code.into())), ); - let result = evm - .call_raw(caller, callee, calldata.into(), 0.into()) - .unwrap(); + let result = evm.call_raw(caller, callee, calldata.into(), 0.into()); let costs = result .logs .into_iter() - .map(|log| h256_to_u256_be(log.topics[0]).as_u64()) + .map(|log| U256::from_big_endian(log.topics[0].as_bytes()).as_u64()) .collect_vec(); if debug { diff --git a/src/loader/evm/test/tui.rs b/src/loader/evm/test/tui.rs index fcaef36c..c0c4d7f8 100644 --- a/src/loader/evm/test/tui.rs +++ b/src/loader/evm/test/tui.rs @@ -1,5 +1,6 @@ //! Copied and modified from https://github.com/foundry-rs/foundry/blob/master/ui/src/lib.rs +use crate::loader::evm::util::executor::{CallKind, DebugStep}; use crossterm::{ event::{ self, DisableMouseCapture, EnableMouseCapture, Event, KeyCode, KeyEvent, KeyModifiers, @@ -8,11 +9,8 @@ use crossterm::{ execute, terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen}, }; -use foundry_evm::{ - debug::{DebugStep, Instruction}, - revm::opcode, - Address, CallKind, -}; +use ethereum_types::Address; +use revm::opcode; use std::{ cmp::{max, min}, io, @@ -90,7 +88,7 @@ impl Tui { self.terminal.clear().unwrap(); let mut draw_memory: DrawMemory = DrawMemory::default(); - let debug_call: Vec<(Address, Vec, CallKind)> = self.debug_arena.clone(); + let debug_call = &self.debug_arena; let mut opcode_list: Vec = debug_call[0] .1 .iter() @@ -207,7 +205,7 @@ impl Tui { } KeyCode::Char('s') => { for _ in 0..Tui::buffer_as_number(&self.key_buffer, 1) { - let remaining_ops = opcode_list[self.current_step..].to_vec().clone(); + let remaining_ops = &opcode_list[self.current_step..]; self.current_step += remaining_ops .iter() .enumerate() @@ -233,7 +231,7 @@ impl Tui { } KeyCode::Char('a') => { for _ in 0..Tui::buffer_as_number(&self.key_buffer, 1) { - let prev_ops = opcode_list[..self.current_step].to_vec().clone(); + let prev_ops = &opcode_list[..self.current_step]; self.current_step = prev_ops .iter() .enumerate() @@ -618,12 +616,7 @@ impl Tui { .borders(Borders::ALL); let min_len = usize::max(format!("{}", stack.len()).len(), 2); - let indices_affected = - if let Instruction::OpCode(op) = debug_steps[current_step].instruction { - stack_indices_affected(op) - } else { - vec![] - }; + let indices_affected = stack_indices_affected(debug_steps[current_step].instruction.0); let text: Vec = stack .iter() @@ -699,33 +692,29 @@ impl Tui { let mut word = None; let mut color = None; - if let Instruction::OpCode(op) = debug_steps[current_step].instruction { - let stack_len = debug_steps[current_step].stack.len(); - if stack_len > 0 { - let w = debug_steps[current_step].stack[stack_len - 1]; - match op { - opcode::MLOAD => { - word = Some(w.as_usize() / 32); - color = Some(Color::Cyan); - } - opcode::MSTORE => { - word = Some(w.as_usize() / 32); - color = Some(Color::Red); - } - _ => {} + let stack_len = debug_steps[current_step].stack.len(); + if stack_len > 0 { + let w = debug_steps[current_step].stack[stack_len - 1]; + match debug_steps[current_step].instruction.0 { + opcode::MLOAD => { + word = Some(w.as_usize() / 32); + color = Some(Color::Cyan); } + opcode::MSTORE => { + word = Some(w.as_usize() / 32); + color = Some(Color::Red); + } + _ => {} } } if current_step > 0 { let prev_step = current_step - 1; let stack_len = debug_steps[prev_step].stack.len(); - if let Instruction::OpCode(op) = debug_steps[prev_step].instruction { - if op == opcode::MSTORE { - let prev_top = debug_steps[prev_step].stack[stack_len - 1]; - word = Some(prev_top.as_usize() / 32); - color = Some(Color::Green); - } + if debug_steps[prev_step].instruction.0 == opcode::MSTORE { + let prev_top = debug_steps[prev_step].stack[stack_len - 1]; + word = Some(prev_top.as_usize() / 32); + color = Some(Color::Green); } } diff --git a/src/loader/evm/util.rs b/src/loader/evm/util.rs index 0d9698bd..0a772513 100644 --- a/src/loader/evm/util.rs +++ b/src/loader/evm/util.rs @@ -5,6 +5,10 @@ use crate::{ use ethereum_types::U256; use std::iter; +pub(crate) mod executor; + +pub use executor::ExecutorBuilder; + pub struct MemoryChunk { ptr: usize, len: usize, diff --git a/src/loader/evm/util/executor.rs b/src/loader/evm/util/executor.rs new file mode 100644 index 00000000..ec9695e0 --- /dev/null +++ b/src/loader/evm/util/executor.rs @@ -0,0 +1,868 @@ +//! Copied and modified from https://github.com/foundry-rs/foundry/blob/master/evm/src/executor/mod.rs + +use bytes::Bytes; +use ethereum_types::{Address, H256, U256, U64}; +use revm::{ + evm_inner, opcode, spec_opcode_gas, Account, BlockEnv, CallInputs, CallScheme, CreateInputs, + CreateScheme, Database, DatabaseCommit, EVMData, Env, ExecutionResult, Gas, GasInspector, + InMemoryDB, Inspector, Interpreter, Memory, OpCode, Return, TransactOut, TransactTo, TxEnv, +}; +use sha3::{Digest, Keccak256}; +use std::{cell::RefCell, collections::HashMap, fmt::Display, rc::Rc}; + +macro_rules! return_ok { + () => { + Return::Continue | Return::Stop | Return::Return | Return::SelfDestruct + }; +} + +fn keccak256(data: impl AsRef<[u8]>) -> [u8; 32] { + Keccak256::digest(data.as_ref()).into() +} + +fn get_contract_address(sender: impl Into
, nonce: impl Into) -> Address { + let mut stream = rlp::RlpStream::new(); + stream.begin_list(2); + stream.append(&sender.into()); + stream.append(&nonce.into()); + + let hash = keccak256(&stream.out()); + + let mut bytes = [0u8; 20]; + bytes.copy_from_slice(&hash[12..]); + Address::from(bytes) +} + +fn get_create2_address( + from: impl Into
, + salt: [u8; 32], + init_code: impl Into, +) -> Address { + get_create2_address_from_hash(from, salt, keccak256(init_code.into().as_ref()).to_vec()) +} + +fn get_create2_address_from_hash( + from: impl Into
, + salt: [u8; 32], + init_code_hash: impl Into, +) -> Address { + let bytes = [ + &[0xff], + from.into().as_bytes(), + salt.as_slice(), + init_code_hash.into().as_ref(), + ] + .concat(); + + let hash = keccak256(&bytes); + + let mut bytes = [0u8; 20]; + bytes.copy_from_slice(&hash[12..]); + Address::from(bytes) +} + +fn get_create_address(call: &CreateInputs, nonce: u64) -> Address { + match call.scheme { + CreateScheme::Create => get_contract_address(call.caller, nonce), + CreateScheme::Create2 { salt } => { + let mut buffer: [u8; 4 * 8] = [0; 4 * 8]; + salt.to_big_endian(&mut buffer); + get_create2_address(call.caller, buffer, call.init_code.clone()) + } + } +} + +#[derive(Clone, Debug, Default)] +pub struct Log { + pub address: Address, + pub topics: Vec, + pub data: Bytes, + pub block_hash: Option, + pub block_number: Option, + pub transaction_hash: Option, + pub transaction_index: Option, + pub log_index: Option, + pub transaction_log_index: Option, + pub log_type: Option, + pub removed: Option, +} + +#[derive(Clone, Debug, Default)] +struct LogCollector { + logs: Vec, +} + +impl Inspector for LogCollector { + fn log(&mut self, _: &mut EVMData<'_, DB>, address: &Address, topics: &[H256], data: &Bytes) { + self.logs.push(Log { + address: *address, + topics: topics.to_vec(), + data: data.clone(), + ..Default::default() + }); + } + + fn call( + &mut self, + _: &mut EVMData<'_, DB>, + call: &mut CallInputs, + _: bool, + ) -> (Return, Gas, Bytes) { + (Return::Continue, Gas::new(call.gas_limit), Bytes::new()) + } +} + +#[derive(Clone, Debug, Copy)] +pub enum CallKind { + Call, + StaticCall, + CallCode, + DelegateCall, + Create, + Create2, +} + +impl Default for CallKind { + fn default() -> Self { + CallKind::Call + } +} + +impl From for CallKind { + fn from(scheme: CallScheme) -> Self { + match scheme { + CallScheme::Call => CallKind::Call, + CallScheme::StaticCall => CallKind::StaticCall, + CallScheme::CallCode => CallKind::CallCode, + CallScheme::DelegateCall => CallKind::DelegateCall, + } + } +} + +impl From for CallKind { + fn from(create: CreateScheme) -> Self { + match create { + CreateScheme::Create => CallKind::Create, + CreateScheme::Create2 { .. } => CallKind::Create2, + } + } +} + +#[derive(Clone, Debug, Default)] +pub struct DebugArena { + pub arena: Vec, +} + +impl DebugArena { + fn push_node(&mut self, mut new_node: DebugNode) -> usize { + fn recursively_push( + arena: &mut Vec, + entry: usize, + mut new_node: DebugNode, + ) -> usize { + match new_node.depth { + _ if arena[entry].depth == new_node.depth - 1 => { + let id = arena.len(); + new_node.location = arena[entry].children.len(); + new_node.parent = Some(entry); + arena[entry].children.push(id); + arena.push(new_node); + id + } + _ => { + let child = *arena[entry].children.last().unwrap(); + recursively_push(arena, child, new_node) + } + } + } + + if self.arena.is_empty() { + self.arena.push(new_node); + 0 + } else if new_node.depth == 0 { + let id = self.arena.len(); + new_node.location = self.arena[0].children.len(); + new_node.parent = Some(0); + self.arena[0].children.push(id); + self.arena.push(new_node); + id + } else { + recursively_push(&mut self.arena, 0, new_node) + } + } + + #[cfg(test)] + pub fn flatten(&self, entry: usize) -> Vec<(Address, Vec, CallKind)> { + let node = &self.arena[entry]; + + let mut flattened = vec![]; + if !node.steps.is_empty() { + flattened.push((node.address, node.steps.clone(), node.kind)); + } + flattened.extend(node.children.iter().flat_map(|child| self.flatten(*child))); + + flattened + } +} + +#[derive(Clone, Debug, Default)] +pub struct DebugNode { + pub parent: Option, + pub children: Vec, + pub location: usize, + pub address: Address, + pub kind: CallKind, + pub depth: usize, + pub steps: Vec, +} + +#[derive(Clone, Debug)] +pub struct DebugStep { + pub stack: Vec, + pub memory: Memory, + pub instruction: Instruction, + pub push_bytes: Option>, + pub pc: usize, + pub total_gas_used: u64, +} + +impl Default for DebugStep { + fn default() -> Self { + Self { + stack: vec![], + memory: Memory::new(), + instruction: Instruction(revm::opcode::INVALID), + push_bytes: None, + pc: 0, + total_gas_used: 0, + } + } +} + +impl DebugStep { + #[cfg(test)] + pub fn pretty_opcode(&self) -> String { + if let Some(push_bytes) = &self.push_bytes { + format!("{}(0x{})", self.instruction, hex::encode(push_bytes)) + } else { + self.instruction.to_string() + } + } +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct Instruction(pub u8); + +impl From for Instruction { + fn from(op: u8) -> Instruction { + Instruction(op) + } +} + +impl Display for Instruction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + OpCode::try_from_u8(self.0).map_or_else( + || format!("UNDEFINED(0x{:02x})", self.0), + |opcode| opcode.as_str().to_string(), + ) + ) + } +} + +#[derive(Clone, Debug)] +struct Debugger { + arena: DebugArena, + head: usize, + context: Address, + gas_inspector: Rc>, +} + +impl Debugger { + fn new(gas_inspector: Rc>) -> Self { + Self { + arena: Default::default(), + head: Default::default(), + context: Default::default(), + gas_inspector, + } + } + + fn enter(&mut self, depth: usize, address: Address, kind: CallKind) { + self.context = address; + self.head = self.arena.push_node(DebugNode { + depth, + address, + kind, + ..Default::default() + }); + } + + fn exit(&mut self) { + if let Some(parent_id) = self.arena.arena[self.head].parent { + let DebugNode { + depth, + address, + kind, + .. + } = self.arena.arena[parent_id]; + self.context = address; + self.head = self.arena.push_node(DebugNode { + depth, + address, + kind, + ..Default::default() + }); + } + } +} + +impl Inspector for Debugger { + fn step( + &mut self, + interpreter: &mut Interpreter, + data: &mut EVMData<'_, DB>, + _is_static: bool, + ) -> Return { + let pc = interpreter.program_counter(); + let op = interpreter.contract.bytecode.bytecode()[pc]; + + let opcode_infos = spec_opcode_gas(data.env.cfg.spec_id); + let opcode_info = &opcode_infos[op as usize]; + + let push_size = if opcode_info.is_push() { + (op - opcode::PUSH1 + 1) as usize + } else { + 0 + }; + let push_bytes = match push_size { + 0 => None, + n => { + let start = pc + 1; + let end = start + n; + Some(interpreter.contract.bytecode.bytecode()[start..end].to_vec()) + } + }; + + let spent = interpreter.gas.limit() - self.gas_inspector.borrow().gas_remaining(); + let total_gas_used = spent - (interpreter.gas.refunded() as u64).min(spent / 5); + + self.arena.arena[self.head].steps.push(DebugStep { + pc, + stack: interpreter.stack().data().clone(), + memory: interpreter.memory.clone(), + instruction: Instruction(op), + push_bytes, + total_gas_used, + }); + + Return::Continue + } + + fn call( + &mut self, + data: &mut EVMData<'_, DB>, + call: &mut CallInputs, + _: bool, + ) -> (Return, Gas, Bytes) { + self.enter( + data.journaled_state.depth() as usize, + call.context.code_address, + call.context.scheme.into(), + ); + + (Return::Continue, Gas::new(call.gas_limit), Bytes::new()) + } + + fn call_end( + &mut self, + _: &mut EVMData<'_, DB>, + _: &CallInputs, + gas: Gas, + status: Return, + retdata: Bytes, + _: bool, + ) -> (Return, Gas, Bytes) { + self.exit(); + + (status, gas, retdata) + } + + fn create( + &mut self, + data: &mut EVMData<'_, DB>, + call: &mut CreateInputs, + ) -> (Return, Option
, Gas, Bytes) { + let nonce = data.journaled_state.account(call.caller).info.nonce; + self.enter( + data.journaled_state.depth() as usize, + get_create_address(call, nonce), + CallKind::Create, + ); + + ( + Return::Continue, + None, + Gas::new(call.gas_limit), + Bytes::new(), + ) + } + + fn create_end( + &mut self, + _: &mut EVMData<'_, DB>, + _: &CreateInputs, + status: Return, + address: Option
, + gas: Gas, + retdata: Bytes, + ) -> (Return, Option
, Gas, Bytes) { + self.exit(); + + (status, address, gas, retdata) + } +} + +#[macro_export] +macro_rules! call_inspectors { + ($id:ident, [ $($inspector:expr),+ ], $call:block) => { + $({ + if let Some($id) = $inspector { + $call; + } + })+ + } +} + +struct InspectorData { + logs: Vec, + debug: Option, +} + +#[derive(Default)] +struct InspectorStack { + gas: Option>>, + logs: Option, + debugger: Option, +} + +impl InspectorStack { + fn collect_inspector_states(self) -> InspectorData { + InspectorData { + logs: self.logs.map(|logs| logs.logs).unwrap_or_default(), + debug: self.debugger.map(|debugger| debugger.arena), + } + } +} + +impl Inspector for InspectorStack { + fn initialize_interp( + &mut self, + interpreter: &mut Interpreter, + data: &mut EVMData<'_, DB>, + is_static: bool, + ) -> Return { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let status = inspector.initialize_interp(interpreter, data, is_static); + + if status != Return::Continue { + return status; + } + } + ); + + Return::Continue + } + + fn step( + &mut self, + interpreter: &mut Interpreter, + data: &mut EVMData<'_, DB>, + is_static: bool, + ) -> Return { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let status = inspector.step(interpreter, data, is_static); + + if status != Return::Continue { + return status; + } + } + ); + + Return::Continue + } + + fn log( + &mut self, + evm_data: &mut EVMData<'_, DB>, + address: &Address, + topics: &[H256], + data: &Bytes, + ) { + call_inspectors!(inspector, [&mut self.logs], { + inspector.log(evm_data, address, topics, data); + }); + } + + fn step_end( + &mut self, + interpreter: &mut Interpreter, + data: &mut EVMData<'_, DB>, + is_static: bool, + status: Return, + ) -> Return { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let status = inspector.step_end(interpreter, data, is_static, status); + + if status != Return::Continue { + return status; + } + } + ); + + Return::Continue + } + + fn call( + &mut self, + data: &mut EVMData<'_, DB>, + call: &mut CallInputs, + is_static: bool, + ) -> (Return, Gas, Bytes) { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let (status, gas, retdata) = inspector.call(data, call, is_static); + + if status != Return::Continue { + return (status, gas, retdata); + } + } + ); + + (Return::Continue, Gas::new(call.gas_limit), Bytes::new()) + } + + fn call_end( + &mut self, + data: &mut EVMData<'_, DB>, + call: &CallInputs, + remaining_gas: Gas, + status: Return, + retdata: Bytes, + is_static: bool, + ) -> (Return, Gas, Bytes) { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let (new_status, new_gas, new_retdata) = inspector.call_end( + data, + call, + remaining_gas, + status, + retdata.clone(), + is_static, + ); + + if new_status != status || (new_status == Return::Revert && new_retdata != retdata) + { + return (new_status, new_gas, new_retdata); + } + } + ); + + (status, remaining_gas, retdata) + } + + fn create( + &mut self, + data: &mut EVMData<'_, DB>, + call: &mut CreateInputs, + ) -> (Return, Option
, Gas, Bytes) { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let (status, addr, gas, retdata) = inspector.create(data, call); + + if status != Return::Continue { + return (status, addr, gas, retdata); + } + } + ); + + ( + Return::Continue, + None, + Gas::new(call.gas_limit), + Bytes::new(), + ) + } + + fn create_end( + &mut self, + data: &mut EVMData<'_, DB>, + call: &CreateInputs, + status: Return, + address: Option
, + remaining_gas: Gas, + retdata: Bytes, + ) -> (Return, Option
, Gas, Bytes) { + call_inspectors!( + inspector, + [ + &mut self.gas.as_deref().map(|gas| gas.borrow_mut()), + &mut self.logs, + &mut self.debugger + ], + { + let (new_status, new_address, new_gas, new_retdata) = inspector.create_end( + data, + call, + status, + address, + remaining_gas, + retdata.clone(), + ); + + if new_status != status { + return (new_status, new_address, new_gas, new_retdata); + } + } + ); + + (status, address, remaining_gas, retdata) + } + + fn selfdestruct(&mut self) { + call_inspectors!(inspector, [&mut self.logs, &mut self.debugger], { + Inspector::::selfdestruct(inspector); + }); + } +} + +pub struct RawCallResult { + pub exit_reason: Return, + pub reverted: bool, + pub result: Bytes, + pub gas_used: u64, + pub gas_refunded: u64, + pub logs: Vec, + pub debug: Option, + pub state_changeset: Option>, + pub env: Env, + pub out: TransactOut, +} + +#[derive(Clone, Debug)] +pub struct DeployResult { + pub exit_reason: Return, + pub reverted: bool, + pub address: Option
, + pub gas_used: u64, + pub gas_refunded: u64, + pub logs: Vec, + pub debug: Option, + pub env: Env, +} + +#[derive(Debug, Default)] +pub struct ExecutorBuilder { + debugger: bool, + gas_limit: Option, +} + +impl ExecutorBuilder { + pub fn set_debugger(mut self, enable: bool) -> Self { + self.debugger = enable; + self + } + + pub fn with_gas_limit(mut self, gas_limit: U256) -> Self { + self.gas_limit = Some(gas_limit); + self + } + + pub fn build(self) -> Executor { + Executor::new(self.debugger, self.gas_limit.unwrap_or(U256::MAX)) + } +} + +#[derive(Clone, Debug)] +pub struct Executor { + db: InMemoryDB, + debugger: bool, + gas_limit: U256, +} + +impl Executor { + fn new(debugger: bool, gas_limit: U256) -> Self { + Executor { + db: InMemoryDB::default(), + debugger, + gas_limit, + } + } + + pub fn db_mut(&mut self) -> &mut InMemoryDB { + &mut self.db + } + + pub fn deploy(&mut self, from: Address, code: Bytes, value: U256) -> DeployResult { + let env = self.build_test_env(from, TransactTo::Create(CreateScheme::Create), code, value); + let result = self.call_raw_with_env(env); + self.commit(&result); + + let RawCallResult { + exit_reason, + out, + gas_used, + gas_refunded, + logs, + debug, + env, + .. + } = result; + + let address = match (exit_reason, out) { + (return_ok!(), TransactOut::Create(_, Some(address))) => Some(address), + _ => None, + }; + + DeployResult { + exit_reason, + reverted: !matches!(exit_reason, return_ok!()), + address, + gas_used, + gas_refunded, + logs, + debug, + env, + } + } + + pub fn call_raw( + &self, + from: Address, + to: Address, + calldata: Bytes, + value: U256, + ) -> RawCallResult { + let env = self.build_test_env(from, TransactTo::Call(to), calldata, value); + self.call_raw_with_env(env) + } + + fn call_raw_with_env(&self, mut env: Env) -> RawCallResult { + let mut inspector = self.inspector(); + let result = + evm_inner::<_, true>(&mut env, &mut self.db.clone(), &mut inspector).transact(); + let (exec_result, state_changeset) = result; + let ExecutionResult { + exit_reason, + gas_refunded, + gas_used, + out, + .. + } = exec_result; + + let result = match out { + TransactOut::Call(ref data) => data.to_owned(), + _ => Bytes::default(), + }; + let InspectorData { logs, debug } = inspector.collect_inspector_states(); + + RawCallResult { + exit_reason, + reverted: !matches!(exit_reason, return_ok!()), + result, + gas_used, + gas_refunded, + logs: logs.to_vec(), + debug, + state_changeset: Some(state_changeset.into_iter().collect()), + env, + out, + } + } + + fn commit(&mut self, result: &RawCallResult) { + if let Some(state_changeset) = result.state_changeset.as_ref() { + self.db + .commit(state_changeset.clone().into_iter().collect()); + } + } + + fn inspector(&self) -> InspectorStack { + let mut stack = InspectorStack { + logs: Some(LogCollector::default()), + ..Default::default() + }; + if self.debugger { + let gas_inspector = Rc::new(RefCell::new(GasInspector::default())); + stack.gas = Some(gas_inspector.clone()); + stack.debugger = Some(Debugger::new(gas_inspector)); + } + stack + } + + fn build_test_env( + &self, + caller: Address, + transact_to: TransactTo, + data: Bytes, + value: U256, + ) -> Env { + Env { + block: BlockEnv { + gas_limit: self.gas_limit, + ..BlockEnv::default() + }, + tx: TxEnv { + caller, + transact_to, + data, + value, + gas_limit: self.gas_limit.as_u64(), + ..TxEnv::default() + }, + ..Env::default() + } + } +} diff --git a/src/loader/halo2.rs b/src/loader/halo2.rs index 39e95c23..5b3361fa 100644 --- a/src/loader/halo2.rs +++ b/src/loader/halo2.rs @@ -23,7 +23,7 @@ mod util { Self: Sized, F: FnMut(B, V) -> B, { - self.into_iter().fold(Value::known(init), |acc, value| { + self.fold(Value::known(init), |acc, value| { acc.zip(value).map(|(acc, value)| f(acc, value)) }) } @@ -49,9 +49,7 @@ where self.transcript_initial_state .as_ref() .map(|transcript_initial_state| { - loader.assign_scalar(circuit::Value::known( - loader.scalar_chip().integer(*transcript_initial_state), - )) + loader.assign_scalar(circuit::Value::known(*transcript_initial_state)) }); Protocol { domain: self.domain.clone(), @@ -64,7 +62,7 @@ where quotient: self.quotient.clone(), transcript_initial_state, instance_committing_key: self.instance_committing_key.clone(), - linearization: self.linearization.clone(), + linearization: self.linearization, accumulator_indices: self.accumulator_indices.clone(), } } diff --git a/src/loader/halo2/loader.rs b/src/loader/halo2/loader.rs index d446c716..67c2b12d 100644 --- a/src/loader/halo2/loader.rs +++ b/src/loader/halo2/loader.rs @@ -45,15 +45,15 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc self.ctx.into_inner() } - pub fn ecc_chip(&self) -> Ref<'_, EccChip> { + pub fn ecc_chip(&self) -> Ref { self.ecc_chip.borrow() } - pub fn scalar_chip(&self) -> Ref<'_, EccChip::ScalarChip> { + pub fn scalar_chip(&self) -> Ref { Ref::map(self.ecc_chip(), |ecc_chip| ecc_chip.scalar_chip()) } - pub fn ctx(&self) -> Ref<'_, EccChip::Context> { + pub fn ctx(&self) -> Ref { self.ctx.borrow() } @@ -61,17 +61,15 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc self.ctx.borrow_mut() } - fn assign_const_scalar(self: &Rc, constant: C::Scalar) -> Scalar<'a, C, EccChip> { - let assigned = self - .scalar_chip() + fn assign_const_scalar(self: &Rc, constant: C::Scalar) -> EccChip::AssignedScalar { + self.scalar_chip() .assign_constant(&mut self.ctx_mut(), constant) - .unwrap(); - self.scalar_from_assigned(assigned) + .unwrap() } pub fn assign_scalar( self: &Rc, - scalar: circuit::Value, + scalar: circuit::Value, ) -> Scalar<'a, C, EccChip> { let assigned = self .scalar_chip() @@ -96,16 +94,14 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc Scalar { loader: self.clone(), index, - value, + value: value.into(), } } - fn assign_const_ec_point(self: &Rc, constant: C) -> EcPoint<'a, C, EccChip> { - self.ec_point_from_assigned( - self.ecc_chip() - .assign_constant(&mut self.ctx_mut(), constant) - .unwrap(), - ) + fn assign_const_ec_point(self: &Rc, constant: C) -> EccChip::AssignedEcPoint { + self.ecc_chip() + .assign_constant(&mut self.ctx_mut(), constant) + .unwrap() } pub fn assign_ec_point( @@ -135,7 +131,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc EcPoint { loader: self.clone(), index, - value, + value: value.into(), } } @@ -144,14 +140,14 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc lhs: &Scalar<'a, C, EccChip>, rhs: &Scalar<'a, C, EccChip>, ) -> Scalar<'a, C, EccChip> { - let output = match (&lhs.value, &rhs.value) { + let output = match (lhs.value().deref(), rhs.value().deref()) { (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs + rhs), (Value::Assigned(assigned), Value::Constant(constant)) | (Value::Constant(constant), Value::Assigned(assigned)) => self .scalar_chip() .sum_with_coeff_and_const( &mut self.ctx_mut(), - &[(C::Scalar::one(), assigned.clone())], + &[(C::Scalar::one(), assigned)], *constant, ) .map(Value::Assigned) @@ -160,10 +156,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc .scalar_chip() .sum_with_coeff_and_const( &mut self.ctx_mut(), - &[ - (C::Scalar::one(), lhs.clone()), - (C::Scalar::one(), rhs.clone()), - ], + &[(C::Scalar::one(), lhs), (C::Scalar::one(), rhs)], C::Scalar::zero(), ) .map(Value::Assigned) @@ -177,13 +170,13 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc lhs: &Scalar<'a, C, EccChip>, rhs: &Scalar<'a, C, EccChip>, ) -> Scalar<'a, C, EccChip> { - let output = match (&lhs.value, &rhs.value) { + let output = match (lhs.value().deref(), rhs.value().deref()) { (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs - rhs), (Value::Constant(constant), Value::Assigned(assigned)) => self .scalar_chip() .sum_with_coeff_and_const( &mut self.ctx_mut(), - &[(-C::Scalar::one(), assigned.clone())], + &[(-C::Scalar::one(), assigned)], *constant, ) .map(Value::Assigned) @@ -192,7 +185,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc .scalar_chip() .sum_with_coeff_and_const( &mut self.ctx_mut(), - &[(C::Scalar::one(), assigned.clone())], + &[(C::Scalar::one(), assigned)], -*constant, ) .map(Value::Assigned) @@ -211,14 +204,14 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc lhs: &Scalar<'a, C, EccChip>, rhs: &Scalar<'a, C, EccChip>, ) -> Scalar<'a, C, EccChip> { - let output = match (&lhs.value, &rhs.value) { + let output = match (lhs.value().deref(), rhs.value().deref()) { (Value::Constant(lhs), Value::Constant(rhs)) => Value::Constant(*lhs * rhs), (Value::Assigned(assigned), Value::Constant(constant)) | (Value::Constant(constant), Value::Assigned(assigned)) => self .scalar_chip() .sum_with_coeff_and_const( &mut self.ctx_mut(), - &[(*constant, assigned.clone())], + &[(*constant, assigned)], C::Scalar::zero(), ) .map(Value::Assigned) @@ -227,7 +220,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc .scalar_chip() .sum_products_with_coeff_and_const( &mut self.ctx_mut(), - &[(C::Scalar::one(), lhs.clone(), rhs.clone())], + &[(C::Scalar::one(), lhs, rhs)], C::Scalar::zero(), ) .map(Value::Assigned) @@ -237,7 +230,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc } fn neg(self: &Rc, scalar: &Scalar<'a, C, EccChip>) -> Scalar<'a, C, EccChip> { - let output = match &scalar.value { + let output = match scalar.value().deref() { Value::Constant(constant) => Value::Constant(constant.neg()), Value::Assigned(assigned) => { IntegerInstructions::neg(self.scalar_chip().deref(), &mut self.ctx_mut(), assigned) @@ -249,7 +242,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc } fn invert(self: &Rc, scalar: &Scalar<'a, C, EccChip>) -> Scalar<'a, C, EccChip> { - let output = match &scalar.value { + let output = match scalar.value().deref() { Value::Constant(constant) => Value::Constant(Field::invert(constant).unwrap()), Value::Assigned(assigned) => Value::Assigned( IntegerInstructions::invert( @@ -295,11 +288,30 @@ pub enum Value { Assigned(L), } +impl Value { + fn maybe_const(&self) -> Option + where + T: Copy, + { + match self { + Value::Constant(constant) => Some(*constant), + _ => None, + } + } + + fn assigned(&self) -> &L { + match self { + Value::Assigned(assigned) => assigned, + _ => unreachable!(), + } + } +} + #[derive(Clone)] pub struct Scalar<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { loader: Rc>, index: usize, - value: Value, + value: RefCell>, } impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Scalar<'a, C, EccChip> { @@ -307,11 +319,19 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Scalar<'a, C, EccChip> &self.loader } - pub fn assigned(&self) -> EccChip::AssignedScalar { - match &self.value { - Value::Constant(constant) => self.loader.assign_const_scalar(*constant).assigned(), - Value::Assigned(assigned) => assigned.clone(), + pub fn assigned(&self) -> Ref { + if let Some(constant) = self.maybe_const() { + *self.value.borrow_mut() = Value::Assigned(self.loader.assign_const_scalar(constant)) } + Ref::map(self.value.borrow(), Value::assigned) + } + + fn value(&self) -> Ref> { + self.value.borrow() + } + + fn maybe_const(&self) -> Option { + self.value().deref().maybe_const() } } @@ -453,16 +473,31 @@ impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> MulAssign<&'b Self pub struct EcPoint<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { loader: Rc>, index: usize, - value: Value, + value: RefCell>, } impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPoint<'a, C, EccChip> { - pub fn assigned(&self) -> EccChip::AssignedEcPoint { - match &self.value { - Value::Constant(constant) => self.loader.assign_const_ec_point(*constant).assigned(), - Value::Assigned(assigned) => assigned.clone(), + pub fn into_assigned(self) -> EccChip::AssignedEcPoint { + match self.value.into_inner() { + Value::Constant(constant) => self.loader.assign_const_ec_point(constant), + Value::Assigned(assigned) => assigned, } } + + pub fn assigned(&self) -> Ref { + if let Some(constant) = self.maybe_const() { + *self.value.borrow_mut() = Value::Assigned(self.loader.assign_const_ec_point(constant)) + } + Ref::map(self.value.borrow(), Value::assigned) + } + + fn value(&self) -> Ref> { + self.value.borrow() + } + + fn maybe_const(&self) -> Option { + self.value().deref().maybe_const() + } } impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> PartialEq for EcPoint<'a, C, EccChip> { @@ -558,35 +593,30 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPointLoader lhs: &EcPoint<'a, C, EccChip>, rhs: &EcPoint<'a, C, EccChip>, ) -> Result<(), crate::Error> { - match (&lhs.value, &rhs.value) { - (Value::Constant(lhs), Value::Constant(rhs)) => { - assert_eq!(lhs, rhs); - Ok(()) - } - (Value::Constant(constant), Value::Assigned(assigned)) - | (Value::Assigned(assigned), Value::Constant(constant)) => { - let constant = self.assign_const_ec_point(*constant).assigned(); - self.ecc_chip() - .assert_equal(&mut self.ctx_mut(), assigned, &constant) - .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())) - } - (Value::Assigned(lhs), Value::Assigned(rhs)) => self - .ecc_chip() - .assert_equal(&mut self.ctx_mut(), lhs, rhs) - .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())), + if let (Value::Constant(lhs), Value::Constant(rhs)) = + (lhs.value().deref(), rhs.value().deref()) + { + assert_eq!(lhs, rhs); + Ok(()) + } else { + let lhs = lhs.assigned(); + let rhs = rhs.assigned(); + self.ecc_chip() + .assert_equal(&mut self.ctx_mut(), lhs.deref(), rhs.deref()) + .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())) } } fn multi_scalar_multiplication( pairs: &[( - >::LoadedScalar, - EcPoint<'a, C, EccChip>, + &>::LoadedScalar, + &EcPoint<'a, C, EccChip>, )], ) -> EcPoint<'a, C, EccChip> { let loader = &pairs[0].0.loader; let (constant, fixed_base, variable_base_non_scaled, variable_base_scaled) = - pairs.iter().fold( + pairs.iter().cloned().fold( (C::identity(), Vec::new(), Vec::new(), Vec::new()), |( mut constant, @@ -594,22 +624,20 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPointLoader mut variable_base_non_scaled, mut variable_base_scaled, ), - (scalar, ec_point)| { - match (&ec_point.value, &scalar.value) { - (Value::Constant(ec_point), Value::Constant(scalar)) => { - constant = (*ec_point * scalar + constant).into() + (scalar, base)| { + match (scalar.value().deref(), base.value().deref()) { + (Value::Constant(scalar), Value::Constant(base)) => { + constant = (*base * scalar + constant).into() } - (Value::Constant(ec_point), Value::Assigned(scalar)) => { - fixed_base.push((scalar.clone(), *ec_point)) + (Value::Assigned(_), Value::Constant(base)) => { + fixed_base.push((scalar, *base)) } - (Value::Assigned(ec_point), Value::Constant(scalar)) + (Value::Constant(scalar), Value::Assigned(_)) if scalar.eq(&C::Scalar::one()) => { - variable_base_non_scaled.push(ec_point.clone()); - } - (Value::Assigned(ec_point), _) => { - variable_base_scaled.push((scalar.assigned(), ec_point.clone())) + variable_base_non_scaled.push(base); } + _ => variable_base_scaled.push((scalar, base)), }; ( constant, @@ -620,38 +648,47 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPointLoader }, ); - let fixed_base_msm = (!fixed_base.is_empty()).then(|| { - loader - .ecc_chip - .borrow_mut() - .fixed_base_msm(&mut loader.ctx_mut(), &fixed_base) - .unwrap() - }); - let variable_base_msm = (!variable_base_scaled.is_empty()).then(|| { - loader - .ecc_chip - .borrow_mut() - .variable_base_msm(&mut loader.ctx_mut(), &variable_base_scaled) - .unwrap() - }); + let fixed_base_msm = (!fixed_base.is_empty()) + .then(|| { + let fixed_base = fixed_base + .into_iter() + .map(|(scalar, base)| (scalar.assigned(), base)) + .collect_vec(); + loader + .ecc_chip + .borrow_mut() + .fixed_base_msm(&mut loader.ctx_mut(), &fixed_base) + .unwrap() + }) + .map(RefCell::new); + let variable_base_msm = (!variable_base_scaled.is_empty()) + .then(|| { + let variable_base_scaled = variable_base_scaled + .into_iter() + .map(|(scalar, base)| (scalar.assigned(), base.assigned())) + .collect_vec(); + loader + .ecc_chip + .borrow_mut() + .variable_base_msm(&mut loader.ctx_mut(), &variable_base_scaled) + .unwrap() + }) + .map(RefCell::new); let output = loader .ecc_chip() .sum_with_const( &mut loader.ctx_mut(), &variable_base_non_scaled .into_iter() - .chain(fixed_base_msm) - .chain(variable_base_msm) + .map(EcPoint::assigned) + .chain(fixed_base_msm.as_ref().map(RefCell::borrow)) + .chain(variable_base_msm.as_ref().map(RefCell::borrow)) .collect_vec(), constant, ) .unwrap(); - let normalized = loader - .ecc_chip() - .normalize(&mut loader.ctx_mut(), &output) - .unwrap(); - loader.ec_point_from_assigned(normalized) + loader.ec_point_from_assigned(output) } } diff --git a/src/loader/halo2/shim.rs b/src/loader/halo2/shim.rs index 97921d24..1d7b6258 100644 --- a/src/loader/halo2/shim.rs +++ b/src/loader/halo2/shim.rs @@ -3,7 +3,7 @@ use halo2_proofs::{ circuit::{Cell, Value}, plonk::Error, }; -use std::fmt::Debug; +use std::{fmt::Debug, ops::Deref}; pub trait Context: Debug { fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error>; @@ -13,15 +13,13 @@ pub trait Context: Debug { pub trait IntegerInstructions<'a, F: FieldExt>: Clone + Debug { type Context: Context; - type Integer: Clone + Debug; + type AssignedCell: Clone + Debug; type AssignedInteger: Clone + Debug; - fn integer(&self, fe: F) -> Self::Integer; - fn assign_integer( &self, ctx: &mut Self::Context, - integer: Value, + integer: Value, ) -> Result; fn assign_constant( @@ -33,41 +31,45 @@ pub trait IntegerInstructions<'a, F: FieldExt>: Clone + Debug { fn sum_with_coeff_and_const( &self, ctx: &mut Self::Context, - values: &[(F::Scalar, Self::AssignedInteger)], + values: &[(F::Scalar, impl Deref)], constant: F::Scalar, ) -> Result; fn sum_products_with_coeff_and_const( &self, ctx: &mut Self::Context, - values: &[(F::Scalar, Self::AssignedInteger, Self::AssignedInteger)], + values: &[( + F::Scalar, + impl Deref, + impl Deref, + )], constant: F::Scalar, ) -> Result; fn sub( &self, ctx: &mut Self::Context, - a: &Self::AssignedInteger, - b: &Self::AssignedInteger, + lhs: &Self::AssignedInteger, + rhs: &Self::AssignedInteger, ) -> Result; fn neg( &self, ctx: &mut Self::Context, - a: &Self::AssignedInteger, + value: &Self::AssignedInteger, ) -> Result; fn invert( &self, ctx: &mut Self::Context, - a: &Self::AssignedInteger, + value: &Self::AssignedInteger, ) -> Result; fn assert_equal( &self, ctx: &mut Self::Context, - a: &Self::AssignedInteger, - b: &Self::AssignedInteger, + lhs: &Self::AssignedInteger, + rhs: &Self::AssignedInteger, ) -> Result<(), Error>; } @@ -77,57 +79,54 @@ pub trait EccInstructions<'a, C: CurveAffine>: Clone + Debug { 'a, C::Scalar, Context = Self::Context, - Integer = Self::Scalar, + AssignedCell = Self::AssignedCell, AssignedInteger = Self::AssignedScalar, >; - type AssignedEcPoint: Clone + Debug; - type Scalar: Clone + Debug; + type AssignedCell: Clone + Debug; type AssignedScalar: Clone + Debug; + type AssignedEcPoint: Clone + Debug; fn scalar_chip(&self) -> &Self::ScalarChip; fn assign_constant( &self, ctx: &mut Self::Context, - point: C, + ec_point: C, ) -> Result; fn assign_point( &self, ctx: &mut Self::Context, - point: Value, + ec_point: Value, ) -> Result; fn sum_with_const( &self, ctx: &mut Self::Context, - values: &[Self::AssignedEcPoint], + values: &[impl Deref], constant: C, ) -> Result; fn fixed_base_msm( &mut self, ctx: &mut Self::Context, - pairs: &[(Self::AssignedScalar, C)], + pairs: &[(impl Deref, C)], ) -> Result; fn variable_base_msm( &mut self, ctx: &mut Self::Context, - pairs: &[(Self::AssignedScalar, Self::AssignedEcPoint)], - ) -> Result; - - fn normalize( - &self, - ctx: &mut Self::Context, - point: &Self::AssignedEcPoint, + pairs: &[( + impl Deref, + impl Deref, + )], ) -> Result; fn assert_equal( &self, ctx: &mut Self::Context, - a: &Self::AssignedEcPoint, - b: &Self::AssignedEcPoint, + lhs: &Self::AssignedEcPoint, + rhs: &Self::AssignedEcPoint, ) -> Result<(), Error>; } @@ -152,7 +151,7 @@ mod halo2_wrong { AssignedPoint, BaseFieldEccChip, }; use rand::rngs::OsRng; - use std::iter; + use std::{iter, ops::Deref}; impl<'a, F: FieldExt> Context for RegionCtx<'a, F> { fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error> { @@ -166,17 +165,13 @@ mod halo2_wrong { impl<'a, F: FieldExt> IntegerInstructions<'a, F> for MainGate { type Context = RegionCtx<'a, F>; - type Integer = F; + type AssignedCell = AssignedCell; type AssignedInteger = AssignedCell; - fn integer(&self, scalar: F) -> Self::Integer { - scalar - } - fn assign_integer( &self, ctx: &mut Self::Context, - integer: Value, + integer: Value, ) -> Result { self.assign_value(ctx, integer) } @@ -192,7 +187,7 @@ mod halo2_wrong { fn sum_with_coeff_and_const( &self, ctx: &mut Self::Context, - values: &[(F, Self::AssignedInteger)], + values: &[(F, impl Deref)], constant: F, ) -> Result { self.compose( @@ -208,7 +203,11 @@ mod halo2_wrong { fn sum_products_with_coeff_and_const( &self, ctx: &mut Self::Context, - values: &[(F, Self::AssignedInteger, Self::AssignedInteger)], + values: &[( + F, + impl Deref, + impl Deref, + )], constant: F, ) -> Result { match values.len() { @@ -289,39 +288,39 @@ mod halo2_wrong { fn sub( &self, ctx: &mut Self::Context, - a: &Self::AssignedInteger, - b: &Self::AssignedInteger, + lhs: &Self::AssignedInteger, + rhs: &Self::AssignedInteger, ) -> Result { - MainGateInstructions::sub(self, ctx, a, b) + MainGateInstructions::sub(self, ctx, lhs, rhs) } fn neg( &self, ctx: &mut Self::Context, - a: &Self::AssignedInteger, + value: &Self::AssignedInteger, ) -> Result { - MainGateInstructions::neg_with_constant(self, ctx, a, F::zero()) + MainGateInstructions::neg_with_constant(self, ctx, value, F::zero()) } fn invert( &self, ctx: &mut Self::Context, - a: &Self::AssignedInteger, + value: &Self::AssignedInteger, ) -> Result { - MainGateInstructions::invert_unsafe(self, ctx, a) + MainGateInstructions::invert_unsafe(self, ctx, value) } fn assert_equal( &self, ctx: &mut Self::Context, - a: &Self::AssignedInteger, - b: &Self::AssignedInteger, + lhs: &Self::AssignedInteger, + rhs: &Self::AssignedInteger, ) -> Result<(), Error> { let mut eq = true; - a.value().zip(b.value()).map(|(lhs, rhs)| { + lhs.value().zip(rhs.value()).map(|(lhs, rhs)| { eq &= lhs == rhs; }); - MainGateInstructions::assert_equal(self, ctx, a, b) + MainGateInstructions::assert_equal(self, ctx, lhs, rhs) .and(eq.then_some(()).ok_or(Error::Synthesis)) } } @@ -331,9 +330,9 @@ mod halo2_wrong { { type Context = RegionCtx<'a, C::Scalar>; type ScalarChip = MainGate; - type AssignedEcPoint = AssignedPoint; - type Scalar = C::Scalar; + type AssignedCell = AssignedCell; type AssignedScalar = AssignedCell; + type AssignedEcPoint = AssignedPoint; fn scalar_chip(&self) -> &Self::ScalarChip { self.main_gate() @@ -342,65 +341,79 @@ mod halo2_wrong { fn assign_constant( &self, ctx: &mut Self::Context, - point: C, + ec_point: C, ) -> Result { - self.assign_constant(ctx, point) + self.assign_constant(ctx, ec_point) } fn assign_point( &self, ctx: &mut Self::Context, - point: Value, + ec_point: Value, ) -> Result { - self.assign_point(ctx, point) + self.assign_point(ctx, ec_point) } fn sum_with_const( &self, ctx: &mut Self::Context, - values: &[Self::AssignedEcPoint], + values: &[impl Deref], constant: C, ) -> Result { if values.is_empty() { return self.assign_constant(ctx, constant); } - iter::empty() - .chain( - (!bool::from(constant.is_identity())) - .then(|| self.assign_constant(ctx, constant)), - ) - .chain(values.iter().cloned().map(Ok)) + let constant = (!bool::from(constant.is_identity())) + .then(|| self.assign_constant(ctx, constant)) + .transpose()?; + let output = iter::empty() + .chain(constant) + .chain(values.iter().map(|value| value.deref().clone())) + .map(Ok) .reduce(|acc, ec_point| self.add(ctx, &acc?, &ec_point?)) - .unwrap() + .unwrap()?; + self.normalize(ctx, &output) } fn fixed_base_msm( &mut self, ctx: &mut Self::Context, - pairs: &[(Self::AssignedScalar, C)], + pairs: &[(impl Deref, C)], ) -> Result { + assert!(!pairs.is_empty()); + // FIXME: Implement fixed base MSM in halo2_wrong let pairs = pairs .iter() + .filter(|(_, base)| !bool::from(base.is_identity())) .map(|(scalar, base)| { - Ok::<_, Error>((scalar.clone(), self.assign_constant(ctx, *base)?)) + Ok::<_, Error>((scalar.deref().clone(), self.assign_constant(ctx, *base)?)) }) .collect::, _>>()?; + let pairs = pairs + .iter() + .map(|(scalar, base)| (scalar, base)) + .collect_vec(); self.variable_base_msm(ctx, &pairs) } fn variable_base_msm( &mut self, ctx: &mut Self::Context, - pairs: &[(Self::AssignedScalar, Self::AssignedEcPoint)], + pairs: &[( + impl Deref, + impl Deref, + )], ) -> Result { + assert!(!pairs.is_empty()); + const WINDOW_SIZE: usize = 3; let pairs = pairs .iter() - .map(|(scalar, base)| (base.clone(), scalar.clone())) + .map(|(scalar, base)| (base.deref().clone(), scalar.deref().clone())) .collect_vec(); - match self.mul_batch_1d_horizontal(ctx, pairs.clone(), WINDOW_SIZE) { + let output = match self.mul_batch_1d_horizontal(ctx, pairs.clone(), WINDOW_SIZE) { Err(_) => { if self.assign_aux(ctx, WINDOW_SIZE, pairs.len()).is_err() { let aux_generator = Value::known(C::Curve::random(OsRng).into()); @@ -410,30 +423,23 @@ mod halo2_wrong { self.mul_batch_1d_horizontal(ctx, pairs, WINDOW_SIZE) } result => result, - } - } - - fn normalize( - &self, - ctx: &mut Self::Context, - point: &Self::AssignedEcPoint, - ) -> Result { - self.normalize(ctx, point) + }?; + self.normalize(ctx, &output) } fn assert_equal( &self, ctx: &mut Self::Context, - a: &Self::AssignedEcPoint, - b: &Self::AssignedEcPoint, + lhs: &Self::AssignedEcPoint, + rhs: &Self::AssignedEcPoint, ) -> Result<(), Error> { let mut eq = true; - [(a.x(), b.x()), (a.y(), b.y())].map(|(lhs, rhs)| { + [(lhs.x(), rhs.x()), (lhs.y(), rhs.y())].map(|(lhs, rhs)| { lhs.integer().zip(rhs.integer()).map(|(lhs, rhs)| { eq &= lhs.value() == rhs.value(); }); }); - self.assert_equal(ctx, a, b) + self.assert_equal(ctx, lhs, rhs) .and(eq.then_some(()).ok_or(Error::Synthesis)) } } diff --git a/src/loader/native.rs b/src/loader/native.rs index 1451ff76..6fce383a 100644 --- a/src/loader/native.rs +++ b/src/loader/native.rs @@ -54,10 +54,11 @@ impl EcPointLoader for NativeLoader { } fn multi_scalar_multiplication( - pairs: &[(>::LoadedScalar, C)], + pairs: &[(&>::LoadedScalar, &C)], ) -> C { pairs .iter() + .cloned() .map(|(scalar, base)| *base * scalar) .reduce(|acc, value| acc + value) .unwrap() diff --git a/src/pcs.rs b/src/pcs.rs index 65804895..23fdda2a 100644 --- a/src/pcs.rs +++ b/src/pcs.rs @@ -123,7 +123,7 @@ where L: Loader, PCS: PolynomialCommitmentScheme, { - fn from_repr(repr: Vec) -> Result; + fn from_repr(repr: &[&L::LoadedScalar]) -> Result; } impl AccumulatorEncoding for () @@ -132,7 +132,7 @@ where L: Loader, PCS: PolynomialCommitmentScheme, { - fn from_repr(_: Vec) -> Result { + fn from_repr(_: &[&L::LoadedScalar]) -> Result { unimplemented!() } } diff --git a/src/pcs/kzg.rs b/src/pcs/kzg.rs index 9f10bd44..056589a8 100644 --- a/src/pcs/kzg.rs +++ b/src/pcs/kzg.rs @@ -15,6 +15,9 @@ pub use accumulator::{KzgAccumulator, LimbsEncoding}; pub use decider::KzgDecidingKey; pub use multiopen::{Bdfg21, Bdfg21Proof, Gwc19, Gwc19Proof}; +#[cfg(feature = "loader_halo2")] +pub use accumulator::LimbsEncodingInstructions; + #[derive(Clone, Debug)] pub struct Kzg(PhantomData<(M, MOS)>); diff --git a/src/pcs/kzg/accumulation.rs b/src/pcs/kzg/accumulation.rs index cd13fd00..4273ce9a 100644 --- a/src/pcs/kzg/accumulation.rs +++ b/src/pcs/kzg/accumulation.rs @@ -44,16 +44,16 @@ where ) -> Result { let (lhs, rhs) = instances .iter() - .cloned() - .map(|accumulator| (accumulator.lhs, accumulator.rhs)) - .chain(proof.blind.clone()) + .map(|accumulator| (&accumulator.lhs, &accumulator.rhs)) + .chain(proof.blind.as_ref().map(|(lhs, rhs)| (lhs, rhs))) .unzip::<_, _, Vec<_>, Vec<_>>(); let powers_of_r = proof.r.powers(lhs.len()); - let [lhs, rhs] = [lhs, rhs].map(|msms| { - msms.into_iter() + let [lhs, rhs] = [lhs, rhs].map(|bases| { + bases + .into_iter() .zip(powers_of_r.iter()) - .map(|(msm, r)| Msm::::base(msm) * r) + .map(|(base, r)| Msm::::base(base) * r) .sum::>() .evaluate(None) }); @@ -184,7 +184,7 @@ where let powers_of_r = r.powers(lhs.len()); let [lhs, rhs] = [lhs, rhs].map(|msms| { - msms.into_iter() + msms.iter() .zip(powers_of_r.iter()) .map(|(msm, power_of_r)| Msm::::base(msm) * power_of_r) .sum::>() diff --git a/src/pcs/kzg/accumulator.rs b/src/pcs/kzg/accumulator.rs index 17c7bf83..50ce6ba8 100644 --- a/src/pcs/kzg/accumulator.rs +++ b/src/pcs/kzg/accumulator.rs @@ -53,13 +53,22 @@ mod native { Accumulator = KzgAccumulator, >, { - fn from_repr(limbs: Vec) -> Result { + fn from_repr(limbs: &[&C::Scalar]) -> Result { assert_eq!(limbs.len(), 4 * LIMBS); let [lhs_x, lhs_y, rhs_x, rhs_y]: [_; 4] = limbs .chunks(LIMBS) .into_iter() - .map(|limbs| fe_from_limbs::<_, _, LIMBS, BITS>(limbs.try_into().unwrap())) + .map(|limbs| { + fe_from_limbs::<_, _, LIMBS, BITS>( + limbs + .iter() + .map(|limb| **limb) + .collect_vec() + .try_into() + .unwrap(), + ) + }) .collect_vec() .try_into() .unwrap(); @@ -100,7 +109,7 @@ mod evm { Accumulator = KzgAccumulator>, >, { - fn from_repr(limbs: Vec) -> Result { + fn from_repr(limbs: &[&Scalar]) -> Result { assert_eq!(limbs.len(), 4 * LIMBS); let loader = limbs[0].loader(); @@ -122,10 +131,13 @@ mod evm { } } +#[cfg(feature = "loader_halo2")] +pub use halo2::LimbsEncodingInstructions; + #[cfg(feature = "loader_halo2")] mod halo2 { use crate::{ - loader::halo2::{Context, EccInstructions, Halo2Loader, Scalar, Valuetools}, + loader::halo2::{EccInstructions, Halo2Loader, Scalar, Valuetools}, pcs::{ kzg::{KzgAccumulator, LimbsEncoding}, AccumulatorEncoding, PolynomialCommitmentScheme, @@ -136,19 +148,18 @@ mod halo2 { }, Error, }; - use halo2_proofs::circuit::Value; - use halo2_wrong_ecc::{maingate::AssignedValue, AssignedPoint}; - use std::{iter, rc::Rc}; + use halo2_proofs::{circuit::Value, plonk}; + use std::{iter, ops::Deref, rc::Rc}; - fn ec_point_from_assigned_limbs( - limbs: &[AssignedValue], + fn ec_point_from_limbs( + limbs: &[Value<&C::Scalar>], ) -> Value { assert_eq!(limbs.len(), 2 * LIMBS); let [x, y] = [&limbs[..LIMBS], &limbs[LIMBS..]].map(|limbs| { limbs .iter() - .map(|assigned| assigned.value()) + .cloned() .fold_zipped(Vec::new(), |mut acc, limb| { acc.push(*limb); acc @@ -159,6 +170,22 @@ mod halo2 { x.zip(y).map(|(x, y)| C::from_xy(x, y).unwrap()) } + pub trait LimbsEncodingInstructions<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize>: + EccInstructions<'a, C> + { + fn assign_ec_point_from_limbs( + &self, + ctx: &mut Self::Context, + limbs: &[impl Deref], + ) -> Result; + + fn assign_ec_point_to_limbs( + &self, + ctx: &mut Self::Context, + ec_point: impl Deref, + ) -> Result, plonk::Error>; + } + impl<'a, C, PCS, EccChip, const LIMBS: usize, const BITS: usize> AccumulatorEncoding>, PCS> for LimbsEncoding where @@ -168,41 +195,72 @@ mod halo2 { Rc>, Accumulator = KzgAccumulator>>, >, - EccChip: EccInstructions< - 'a, - C, - AssignedEcPoint = AssignedPoint<::Base, C::Scalar, LIMBS, BITS>, - AssignedScalar = AssignedValue, - >, + EccChip: LimbsEncodingInstructions<'a, C, LIMBS, BITS>, { - fn from_repr(limbs: Vec>) -> Result { + fn from_repr(limbs: &[&Scalar<'a, C, EccChip>]) -> Result { assert_eq!(limbs.len(), 4 * LIMBS); let loader = limbs[0].loader(); - let assigned_limbs = limbs.iter().map(|limb| limb.assigned()).collect_vec(); - let [lhs, rhs] = [&assigned_limbs[..2 * LIMBS], &assigned_limbs[2 * LIMBS..]].map( - |assigned_limbs| { - let ec_point = ec_point_from_assigned_limbs::<_, LIMBS, BITS>(assigned_limbs); - loader.assign_ec_point(ec_point) - }, - ); - - for (src, dst) in assigned_limbs.iter().zip( - iter::empty() - .chain(lhs.assigned().x().limbs()) - .chain(lhs.assigned().y().limbs()) - .chain(rhs.assigned().x().limbs()) - .chain(rhs.assigned().y().limbs()), - ) { - loader - .ctx_mut() - .constrain_equal(src.cell(), dst.as_ref().cell()) + let [lhs, rhs] = [&limbs[..2 * LIMBS], &limbs[2 * LIMBS..]].map(|limbs| { + let assigned = loader + .ecc_chip() + .assign_ec_point_from_limbs( + &mut loader.ctx_mut(), + &limbs.iter().map(|limb| limb.assigned()).collect_vec(), + ) .unwrap(); + loader.ec_point_from_assigned(assigned) + }); + + Ok(KzgAccumulator::new(lhs, rhs)) + } + } + + mod halo2_wrong { + use super::*; + use halo2_wrong_ecc::BaseFieldEccChip; + + impl<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize> + LimbsEncodingInstructions<'a, C, LIMBS, BITS> for BaseFieldEccChip + { + fn assign_ec_point_from_limbs( + &self, + ctx: &mut Self::Context, + limbs: &[impl Deref], + ) -> Result { + assert_eq!(limbs.len(), 2 * LIMBS); + + let ec_point = self.assign_point( + ctx, + ec_point_from_limbs::<_, LIMBS, BITS>( + &limbs.iter().map(|limb| limb.value()).collect_vec(), + ), + )?; + + for (src, dst) in limbs.iter().zip_eq( + iter::empty() + .chain(ec_point.x().limbs()) + .chain(ec_point.y().limbs()), + ) { + ctx.constrain_equal(src.cell(), dst.as_ref().cell())?; + } + + Ok(ec_point) } - let accumulator = KzgAccumulator::new(lhs, rhs); - Ok(accumulator) + fn assign_ec_point_to_limbs( + &self, + _: &mut Self::Context, + ec_point: impl Deref, + ) -> Result, plonk::Error> { + Ok(iter::empty() + .chain(ec_point.x().limbs()) + .chain(ec_point.y().limbs()) + .map(|limb| limb.as_ref()) + .cloned() + .collect()) + } } } } diff --git a/src/pcs/kzg/decider.rs b/src/pcs/kzg/decider.rs index b6957883..de3e2a06 100644 --- a/src/pcs/kzg/decider.rs +++ b/src/pcs/kzg/decider.rs @@ -144,7 +144,7 @@ mod evm { let powers_of_challenge = LoadedScalar::::powers(&challenge, lhs.len()); let [lhs, rhs] = [lhs, rhs].map(|msms| { - msms.into_iter() + msms.iter() .zip(powers_of_challenge.iter()) .map(|(msm, power_of_challenge)| { Msm::>::base(msm) * power_of_challenge diff --git a/src/pcs/kzg/multiopen/bdfg21.rs b/src/pcs/kzg/multiopen/bdfg21.rs index 7f70ca61..f542f750 100644 --- a/src/pcs/kzg/multiopen/bdfg21.rs +++ b/src/pcs/kzg/multiopen/bdfg21.rs @@ -47,8 +47,8 @@ where queries: &[Query], proof: &Bdfg21Proof, ) -> Result { + let sets = query_sets(queries); let f = { - let sets = query_sets(queries); let coeffs = query_set_coeffs(&sets, z, &proof.z_prime); let powers_of_mu = proof @@ -62,10 +62,10 @@ where msms.zip(proof.gamma.powers(sets.len()).into_iter()) .map(|(msm, power_of_gamma)| msm * &power_of_gamma) .sum::>() - - Msm::base(proof.w.clone()) * &coeffs[0].z_s + - Msm::base(&proof.w) * &coeffs[0].z_s }; - let rhs = Msm::base(proof.w_prime.clone()); + let rhs = Msm::base(&proof.w_prime); let lhs = f + rhs.clone() * &proof.z_prime; Ok(KzgAccumulator::new( @@ -143,7 +143,7 @@ fn query_sets(queries: &[Query]) -> Vec(queries: &[Query]) -> Vec(queries: &[Query]) -> Vec>( - sets: &[QuerySet], +fn query_set_coeffs<'a, F: FieldExt, T: LoadedScalar>( + sets: &[QuerySet<'a, F, T>], z: &T, z_prime: &T, ) -> Vec> { @@ -211,17 +211,17 @@ fn query_set_coeffs>( } #[derive(Clone, Debug)] -struct QuerySet { +struct QuerySet<'a, F, T> { shifts: Vec, polys: Vec, - evals: Vec>, + evals: Vec>, } -impl> QuerySet { +impl<'a, F: FieldExt, T: LoadedScalar> QuerySet<'a, F, T> { fn msm>( &self, coeff: &QuerySetCoeff, - commitments: &[Msm], + commitments: &[Msm<'a, C, L>], powers_of_mu: &[T], ) -> Msm { self.polys @@ -241,7 +241,7 @@ impl> QuerySet { &coeff .eval_coeffs .iter() - .zip(evals.iter()) + .zip(evals.iter().cloned()) .map(|(coeff, eval)| (coeff.evaluated(), eval)) .collect_vec(), ) * coeff.r_eval_coeff.as_ref().unwrap().evaluated(); @@ -288,18 +288,15 @@ where }) .collect_vec(); - let z = &powers_of_z[1].clone(); + let z = &powers_of_z[1]; let z_pow_k_minus_one = { let k_minus_one = shifts.len() - 1; powers_of_z .iter() .enumerate() .skip(1) - .filter_map(|(i, power_of_z)| { - (k_minus_one & (1 << i) == 1).then(|| power_of_z.clone()) - }) - .reduce(|acc, value| acc * value) - .unwrap_or_else(|| loader.load_one()) + .filter_map(|(i, power_of_z)| (k_minus_one & (1 << i) == 1).then(|| power_of_z)) + .fold(loader.load_one(), |acc, value| acc * value) }; let barycentric_weights = shifts @@ -354,7 +351,7 @@ where .map(Fraction::evaluated) .collect_vec(), ); - self.r_eval_coeff = Some(match self.commitment_coeff.clone() { + self.r_eval_coeff = Some(match self.commitment_coeff.as_ref() { Some(coeff) => Fraction::new(coeff.evaluated().clone(), barycentric_weights_sum), None => Fraction::one_over(barycentric_weights_sum), }); diff --git a/src/pcs/kzg/multiopen/gwc19.rs b/src/pcs/kzg/multiopen/gwc19.rs index 121fce8a..6e3f579f 100644 --- a/src/pcs/kzg/multiopen/gwc19.rs +++ b/src/pcs/kzg/multiopen/gwc19.rs @@ -55,15 +55,13 @@ where .map(|(msm, power_of_u)| msm * power_of_u) .sum::>() }; - let z_omegas = sets - .iter() - .map(|set| z.clone() * &z.loader().load_const(&set.shift)); + let z_omegas = sets.iter().map(|set| z.loader().load_const(&set.shift) * z); let rhs = proof .ws .iter() .zip(powers_of_u.iter()) - .map(|(w, power_of_u)| Msm::base(w.clone()) * power_of_u) + .map(|(w, power_of_u)| Msm::base(w) * power_of_u) .collect_vec(); let lhs = f + rhs .iter() @@ -105,25 +103,25 @@ where } } -struct QuerySet { +struct QuerySet<'a, F, T> { shift: F, polys: Vec, - evals: Vec, + evals: Vec<&'a T>, } -impl QuerySet +impl<'a, F, T> QuerySet<'a, F, T> where F: PrimeField, T: Clone, { fn msm>( &self, - commitments: &[Msm], + commitments: &[Msm<'a, C, L>], powers_of_v: &[L::LoadedScalar], ) -> Msm { self.polys .iter() - .zip(self.evals.iter()) + .zip(self.evals.iter().cloned()) .map(|(poly, eval)| { let commitment = commitments[*poly].clone(); commitment - Msm::constant(eval.clone()) @@ -142,12 +140,12 @@ where queries.iter().fold(Vec::new(), |mut sets, query| { if let Some(pos) = sets.iter().position(|set| set.shift == query.shift) { sets[pos].polys.push(query.poly); - sets[pos].evals.push(query.eval.clone()); + sets[pos].evals.push(&query.eval); } else { sets.push(QuerySet { shift: query.shift, polys: vec![query.poly], - evals: vec![query.eval.clone()], + evals: vec![&query.eval], }); } sets diff --git a/src/system/halo2.rs b/src/system/halo2.rs index 3d7cf140..a49fa6ef 100644 --- a/src/system/halo2.rs +++ b/src/system/halo2.rs @@ -601,11 +601,14 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .zip(permutation_fixeds.chunks(self.permutation_chunk_size)) .enumerate() .map( - |(i, ((((z, z_w, _), (_, z_next_w, _)), polys), permutation_fixeds))| { + |( + i, + ((((z, z_omega, _), (_, z_next_omega, _)), polys), permutation_fixeds), + )| { let left = if self.zk || zs.len() == 1 { - z_w.clone() + z_omega.clone() } else { - z_w + l_last * (z_next_w - z_w) + z_omega + l_last * (z_next_omega - z_omega) } * polys .iter() .zip(permutation_fixeds.iter()) @@ -675,7 +678,10 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .iter() .zip(polys.iter()) .flat_map( - |(lookup, (z, z_w, permuted_input, permuted_input_w_inv, permuted_table))| { + |( + lookup, + (z, z_omega, permuted_input, permuted_input_omega_inv, permuted_table), + )| { let input = compress(lookup.input_expressions()); let table = compress(lookup.table_expressions()); iter::empty() @@ -683,20 +689,20 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .chain(self.zk.then(|| l_last * (z * z - z))) .chain(Some(if self.zk { l_active - * (z_w * (permuted_input + beta) * (permuted_table + gamma) + * (z_omega * (permuted_input + beta) * (permuted_table + gamma) - z * (input + beta) * (table + gamma)) } else { - z_w * (permuted_input + beta) * (permuted_table + gamma) + z_omega * (permuted_input + beta) * (permuted_table + gamma) - z * (input + beta) * (table + gamma) })) .chain(self.zk.then(|| l_0 * (permuted_input - permuted_table))) .chain(Some(if self.zk { l_active * (permuted_input - permuted_table) - * (permuted_input - permuted_input_w_inv) + * (permuted_input - permuted_input_omega_inv) } else { (permuted_input - permuted_table) - * (permuted_input - permuted_input_w_inv) + * (permuted_input - permuted_input_omega_inv) })) }, ) diff --git a/src/system/halo2/test/kzg.rs b/src/system/halo2/test/kzg.rs index 3b7e1694..2b267758 100644 --- a/src/system/halo2/test/kzg.rs +++ b/src/system/halo2/test/kzg.rs @@ -27,8 +27,8 @@ pub fn main_gate_with_range_with_mock_kzg_accumulator( let srs = read_or_create_srs(TESTDATA_DIR, 1, setup::); let [g1, s_g1] = [srs.get_g()[0], srs.get_g()[1]].map(|point| point.coordinates().unwrap()); MainGateWithRange::new( - [*s_g1.x(), *s_g1.y(), *g1.x(), *g1.y()] - .iter() + [s_g1.x(), s_g1.y(), g1.x(), g1.y()] + .into_iter() .cloned() .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) .collect(), diff --git a/src/system/halo2/test/kzg/halo2.rs b/src/system/halo2/test/kzg/halo2.rs index 18767c98..11a6046f 100644 --- a/src/system/halo2/test/kzg/halo2.rs +++ b/src/system/halo2/test/kzg/halo2.rs @@ -7,7 +7,7 @@ use crate::{ pcs::{ kzg::{ Bdfg21, Kzg, KzgAccumulator, KzgAs, KzgAsProvingKey, KzgAsVerifyingKey, - KzgSuccinctVerifyingKey, LimbsEncoding, + KzgSuccinctVerifyingKey, LimbsEncoding, LimbsEncodingInstructions, }, AccumulationScheme, AccumulationSchemeProver, }, @@ -31,7 +31,7 @@ use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; use halo2_proofs::{ circuit::{floor_planner::V1, Layouter, Value}, plonk, - plonk::Circuit, + plonk::{Circuit, Error}, poly::{ commitment::ParamsProver, kzg::{ @@ -48,7 +48,7 @@ use halo2_wrong_ecc::{ }; use paste::paste; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; -use std::{iter, rc::Rc}; +use std::rc::Rc; const T: usize = 5; const RATE: usize = 4; @@ -281,14 +281,14 @@ impl Circuit for Accumulation { range_chip.load_table(&mut layouter)?; - let (lhs, rhs) = layouter.assign_region( + let accumulator_limbs = layouter.assign_region( || "", |region| { let ctx = RegionCtx::new(region, 0); let ecc_chip = config.ecc_chip(); let loader = Halo2Loader::new(ecc_chip, ctx); - let KzgAccumulator { lhs, rhs } = accumulate( + let accumulator = accumulate( &self.svk, &loader, &self.snarks, @@ -296,21 +296,26 @@ impl Circuit for Accumulation { self.as_proof(), ); + let accumulator_limbs = [accumulator.lhs, accumulator.rhs] + .iter() + .map(|ec_point| { + loader + .ecc_chip() + .assign_ec_point_to_limbs(&mut loader.ctx_mut(), ec_point.assigned()) + }) + .collect::, Error>>()? + .into_iter() + .flatten(); + loader.print_row_metering(); println!("Total row cost: {}", loader.ctx().offset()); - Ok((lhs.assigned(), rhs.assigned())) + Ok(accumulator_limbs) }, )?; - for (limb, row) in iter::empty() - .chain(lhs.x().limbs()) - .chain(lhs.y().limbs()) - .chain(rhs.x().limbs()) - .chain(rhs.y().limbs()) - .zip(0..) - { - main_gate.expose_public(layouter.namespace(|| ""), limb.into(), row)?; + for (row, limb) in accumulator_limbs.enumerate() { + main_gate.expose_public(layouter.namespace(|| ""), limb, row)?; } Ok(()) diff --git a/src/system/halo2/transcript/halo2.rs b/src/system/halo2/transcript/halo2.rs index bb7e8e23..ab7d548c 100644 --- a/src/system/halo2/transcript/halo2.rs +++ b/src/system/halo2/transcript/halo2.rs @@ -1,60 +1,58 @@ use crate::{ loader::{ - halo2::{EcPoint, EccInstructions, Halo2Loader, IntegerInstructions, Scalar}, + halo2::{EcPoint, EccInstructions, Halo2Loader, Scalar}, native::{self, NativeLoader}, Loader, ScalarLoader, }, util::{ - arithmetic::{fe_to_fe, CurveAffine, FieldExt, PrimeField}, + arithmetic::{fe_to_fe, CurveAffine, PrimeField}, hash::Poseidon, transcript::{Transcript, TranscriptRead, TranscriptWrite}, Itertools, }, Error, }; -use halo2_proofs::{ - circuit::{AssignedCell, Value}, - transcript::EncodedChallenge, -}; +use halo2_proofs::{circuit::Value, transcript::EncodedChallenge}; use std::{ io::{self, Read, Write}, - marker::PhantomData, rc::Rc, }; -pub trait EncodeNative<'a, C: CurveAffine, N: FieldExt>: EccInstructions<'a, C> { - fn encode_native( +/// Encoding that encodes elliptic curve point into native field elements. +pub trait NativeEncoding<'a, C>: EccInstructions<'a, C> +where + C: CurveAffine, +{ + fn encode( &self, ctx: &mut Self::Context, ec_point: &Self::AssignedEcPoint, - ) -> Result>, Error>; + ) -> Result, Error>; } pub struct PoseidonTranscript< - C: CurveAffine, - L: Loader, + C, + L, S, const T: usize, const RATE: usize, const R_F: usize, const R_P: usize, -> { +> where + C: CurveAffine, + L: Loader, +{ loader: L, stream: S, buf: Poseidon>::LoadedScalar, T, RATE>, - _marker: PhantomData, } -impl< - 'a, - C: CurveAffine, - R: Read, - EccChip: EncodeNative<'a, C, C::Scalar, AssignedScalar = AssignedCell>, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > PoseidonTranscript>, Value, T, RATE, R_F, R_P> +impl<'a, C, R, EccChip, const T: usize, const RATE: usize, const R_F: usize, const R_P: usize> + PoseidonTranscript>, Value, T, RATE, R_F, R_P> +where + C: CurveAffine, + R: Read, + EccChip: NativeEncoding<'a, C>, { pub fn new(loader: &Rc>, stream: Value) -> Self { let buf = Poseidon::new(loader, R_F, R_P); @@ -62,22 +60,17 @@ impl< loader: loader.clone(), stream, buf, - _marker: PhantomData, } } } -impl< - 'a, - C: CurveAffine, - R: Read, - EccChip: EncodeNative<'a, C, C::Scalar, AssignedScalar = AssignedCell>, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > Transcript>> +impl<'a, C, R, EccChip, const T: usize, const RATE: usize, const R_F: usize, const R_P: usize> + Transcript>> for PoseidonTranscript>, Value, T, RATE, R_F, R_P> +where + C: CurveAffine, + R: Read, + EccChip: NativeEncoding<'a, C>, { fn loader(&self) -> &Rc> { &self.loader @@ -96,30 +89,31 @@ impl< let encoded = self .loader .ecc_chip() - .encode_native(&mut self.loader.ctx_mut(), &ec_point.assigned()) + .encode(&mut self.loader.ctx_mut(), &ec_point.assigned()) .map(|encoded| { encoded .into_iter() .map(|encoded| self.loader.scalar_from_assigned(encoded)) .collect_vec() }) - .map_err(|_| Error::Transcript(io::ErrorKind::Other, "".to_string()))?; + .map_err(|_| { + Error::Transcript( + io::ErrorKind::Other, + "Failed to encode elliptic curve point into native field elements".to_string(), + ) + })?; self.buf.update(&encoded); Ok(()) } } -impl< - 'a, - C: CurveAffine, - R: Read, - EccChip: EncodeNative<'a, C, C::Scalar, AssignedScalar = AssignedCell>, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptRead>> +impl<'a, C, R, EccChip, const T: usize, const RATE: usize, const R_F: usize, const R_P: usize> + TranscriptRead>> for PoseidonTranscript>, Value, T, RATE, R_F, R_P> +where + C: CurveAffine, + R: Read, + EccChip: NativeEncoding<'a, C>, { fn read_scalar(&mut self) -> Result, Error> { let scalar = self.stream.as_mut().and_then(|stream| { @@ -128,7 +122,7 @@ impl< return Value::unknown(); } Option::::from(C::Scalar::from_repr(data)) - .map(|scalar| Value::known(self.loader.scalar_chip().integer(scalar))) + .map(Value::known) .unwrap_or_else(Value::unknown) }); let scalar = self.loader.assign_scalar(scalar); @@ -152,8 +146,6 @@ impl< } } -// - impl PoseidonTranscript { @@ -162,7 +154,6 @@ impl TranscriptRead - for PoseidonTranscript +impl + TranscriptRead for PoseidonTranscript +where + C: CurveAffine, + R: Read, { fn read_scalar(&mut self) -> Result { let mut data = ::Repr::default(); @@ -243,14 +230,11 @@ impl< } } -impl< - C: CurveAffine, - W: Write, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > PoseidonTranscript +impl + PoseidonTranscript +where + C: CurveAffine, + W: Write, { pub fn stream_mut(&mut self) -> &mut W { &mut self.stream @@ -261,14 +245,11 @@ impl< } } -impl< - C: CurveAffine, - W: Write, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > TranscriptWrite for PoseidonTranscript +impl TranscriptWrite + for PoseidonTranscript +where + C: CurveAffine, + W: Write, { fn write_scalar(&mut self, scalar: C::Scalar) -> Result<(), Error> { self.common_scalar(&scalar)?; @@ -332,15 +313,12 @@ impl halo2_proofs::transcript::TranscriptRead> +impl + halo2_proofs::transcript::TranscriptRead> for PoseidonTranscript +where + C: CurveAffine, + R: Read, { fn read_point(&mut self) -> io::Result { match TranscriptRead::read_ec_point(self) { @@ -359,30 +337,24 @@ impl< } } -impl< - C: CurveAffine, - R: Read, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > halo2_proofs::transcript::TranscriptReadBuffer> +impl + halo2_proofs::transcript::TranscriptReadBuffer> for PoseidonTranscript +where + C: CurveAffine, + R: Read, { fn init(reader: R) -> Self { Self::new(reader) } } -impl< - C: CurveAffine, - W: Write, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > halo2_proofs::transcript::TranscriptWrite> +impl + halo2_proofs::transcript::TranscriptWrite> for PoseidonTranscript +where + C: CurveAffine, + W: Write, { fn write_point(&mut self, ec_point: C) -> io::Result<()> { halo2_proofs::transcript::Transcript::>::common_point( @@ -399,15 +371,12 @@ impl< } } -impl< - C: CurveAffine, - W: Write, - const T: usize, - const RATE: usize, - const R_F: usize, - const R_P: usize, - > halo2_proofs::transcript::TranscriptWriterBuffer> +impl + halo2_proofs::transcript::TranscriptWriterBuffer> for PoseidonTranscript +where + C: CurveAffine, + W: Write, { fn init(writer: W) -> Self { Self::new(writer) @@ -419,15 +388,15 @@ impl< } mod halo2_wrong { - use crate::system::halo2::transcript::halo2::EncodeNative; + use crate::system::halo2::transcript::halo2::NativeEncoding; use halo2_curves::CurveAffine; use halo2_proofs::circuit::AssignedCell; use halo2_wrong_ecc::BaseFieldEccChip; - impl<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize> EncodeNative<'a, C, C::Scalar> + impl<'a, C: CurveAffine, const LIMBS: usize, const BITS: usize> NativeEncoding<'a, C> for BaseFieldEccChip { - fn encode_native( + fn encode( &self, _: &mut Self::Context, ec_point: &Self::AssignedEcPoint, diff --git a/src/util/arithmetic.rs b/src/util/arithmetic.rs index b9a5c7c6..02d6da32 100644 --- a/src/util/arithmetic.rs +++ b/src/util/arithmetic.rs @@ -186,14 +186,14 @@ impl Fraction { self.eval = Some( self.numer - .as_ref() - .map(|numer| numer.clone() * &self.denom) + .take() + .map(|numer| numer * &self.denom) .unwrap_or_else(|| self.denom.clone()), ); } pub fn evaluated(&self) -> &T { - assert!(self.inv); + assert!(self.eval.is_some()); self.eval.as_ref().unwrap() } @@ -241,19 +241,12 @@ pub fn fe_to_limbs [F2; LIMBS] { let big = BigUint::from_bytes_le(fe.to_repr().as_ref()); - let mask = (BigUint::one() << BITS) - 1usize; + let mask = &((BigUint::one() << BITS) - 1usize); (0usize..) .step_by(BITS) .take(LIMBS) - .map(move |shift| fe_from_big((&big >> shift) & &mask)) + .map(|shift| fe_from_big((&big >> shift) & mask)) .collect_vec() .try_into() .unwrap() } - -pub fn powers(scalar: F) -> impl Iterator -where - for<'a> F: Mul<&'a F, Output = F> + One + Clone, -{ - iter::successors(Some(F::one()), move |power| Some(scalar.clone() * power)) -} diff --git a/src/util/msm.rs b/src/util/msm.rs index a15eab9d..f35ea197 100644 --- a/src/util/msm.rs +++ b/src/util/msm.rs @@ -9,13 +9,13 @@ use std::{ }; #[derive(Clone, Debug)] -pub struct Msm> { +pub struct Msm<'a, C: CurveAffine, L: Loader> { constant: Option, scalars: Vec, - bases: Vec, + bases: Vec<&'a L::LoadedEcPoint>, } -impl Default for Msm +impl<'a, C, L> Default for Msm<'a, C, L> where C: CurveAffine, L: Loader, @@ -29,7 +29,7 @@ where } } -impl Msm +impl<'a, C, L> Msm<'a, C, L> where C: CurveAffine, L: Loader, @@ -41,7 +41,7 @@ where } } - pub fn base(base: L::LoadedEcPoint) -> Self { + pub fn base<'b: 'a>(base: &'b L::LoadedEcPoint) -> Self { let one = base.loader().load_one(); Msm { scalars: vec![one], @@ -72,8 +72,12 @@ where .ec_point_load_const(&gen) }); let pairs = iter::empty() - .chain(self.constant.map(|constant| (constant, gen.unwrap()))) - .chain(self.scalars.into_iter().zip(self.bases.into_iter())) + .chain( + self.constant + .as_ref() + .map(|constant| (constant, gen.as_ref().unwrap())), + ) + .chain(self.scalars.iter().zip(self.bases.into_iter())) .collect_vec(); L::multi_scalar_multiplication(&pairs) } @@ -87,16 +91,16 @@ where } } - pub fn push(&mut self, scalar: L::LoadedScalar, base: L::LoadedEcPoint) { + pub fn push<'b: 'a>(&mut self, scalar: L::LoadedScalar, base: &'b L::LoadedEcPoint) { if let Some(pos) = self.bases.iter().position(|exist| exist.eq(&base)) { - self.scalars[pos] += scalar; + self.scalars[pos] += &scalar; } else { self.scalars.push(scalar); self.bases.push(base); } } - pub fn extend(&mut self, mut other: Self) { + pub fn extend<'b: 'a>(&mut self, mut other: Msm<'b, C, L>) { match (self.constant.as_mut(), other.constant.as_ref()) { (Some(lhs), Some(rhs)) => *lhs += rhs, (None, Some(_)) => self.constant = other.constant.take(), @@ -108,58 +112,62 @@ where } } -impl Add> for Msm +impl<'a, 'b, C, L> Add> for Msm<'a, C, L> where + 'b: 'a, C: CurveAffine, L: Loader, { - type Output = Msm; + type Output = Msm<'a, C, L>; - fn add(mut self, rhs: Msm) -> Self::Output { + fn add(mut self, rhs: Msm<'b, C, L>) -> Self::Output { self.extend(rhs); self } } -impl AddAssign> for Msm +impl<'a, 'b, C, L> AddAssign> for Msm<'a, C, L> where + 'b: 'a, C: CurveAffine, L: Loader, { - fn add_assign(&mut self, rhs: Msm) { + fn add_assign(&mut self, rhs: Msm<'b, C, L>) { self.extend(rhs); } } -impl Sub> for Msm +impl<'a, 'b, C, L> Sub> for Msm<'a, C, L> where + 'b: 'a, C: CurveAffine, L: Loader, { - type Output = Msm; + type Output = Msm<'a, C, L>; - fn sub(mut self, rhs: Msm) -> Self::Output { + fn sub(mut self, rhs: Msm<'b, C, L>) -> Self::Output { self.extend(-rhs); self } } -impl SubAssign> for Msm +impl<'a, 'b, C, L> SubAssign> for Msm<'a, C, L> where + 'b: 'a, C: CurveAffine, L: Loader, { - fn sub_assign(&mut self, rhs: Msm) { + fn sub_assign(&mut self, rhs: Msm<'b, C, L>) { self.extend(-rhs); } } -impl Mul<&L::LoadedScalar> for Msm +impl<'a, C, L> Mul<&L::LoadedScalar> for Msm<'a, C, L> where C: CurveAffine, L: Loader, { - type Output = Msm; + type Output = Msm<'a, C, L>; fn mul(mut self, rhs: &L::LoadedScalar) -> Self::Output { self.scale(rhs); @@ -167,7 +175,7 @@ where } } -impl MulAssign<&L::LoadedScalar> for Msm +impl<'a, C, L> MulAssign<&L::LoadedScalar> for Msm<'a, C, L> where C: CurveAffine, L: Loader, @@ -177,13 +185,13 @@ where } } -impl Neg for Msm +impl<'a, C, L> Neg for Msm<'a, C, L> where C: CurveAffine, L: Loader, { - type Output = Msm; - fn neg(mut self) -> Msm { + type Output = Msm<'a, C, L>; + fn neg(mut self) -> Msm<'a, C, L> { self.constant = self.constant.map(|constant| -constant); for scalar in self.scalars.iter_mut() { *scalar = -scalar.clone(); @@ -192,7 +200,7 @@ where } } -impl Sum for Msm +impl<'a, C, L> Sum for Msm<'a, C, L> where C: CurveAffine, L: Loader, diff --git a/src/util/protocol.rs b/src/util/protocol.rs index e9ecd9f0..f7747060 100644 --- a/src/util/protocol.rs +++ b/src/util/protocol.rs @@ -41,7 +41,7 @@ where quotient: self.quotient.clone(), transcript_initial_state, instance_committing_key: self.instance_committing_key.clone(), - linearization: self.linearization.clone(), + linearization: self.linearization, accumulator_indices: self.accumulator_indices.clone(), } } @@ -82,11 +82,11 @@ where let langranges = langranges.into_iter().sorted().dedup().collect_vec(); let one = loader.load_one(); - let zn_minus_one = zn.clone() - one; + let zn_minus_one = zn.clone() - &one; let zn_minus_one_inv = Fraction::one_over(zn_minus_one.clone()); let n_inv = loader.load_const(&domain.n_inv); - let numer = zn_minus_one.clone() * n_inv; + let numer = zn_minus_one.clone() * &n_inv; let omegas = langranges .iter() .map(|&i| loader.load_const(&domain.rotate_scalar(C::Scalar::one(), Rotation(i)))) @@ -378,7 +378,7 @@ fn merge_left_right(a: Option>, b: Option>) -> O } } -#[derive(Clone, Debug)] +#[derive(Clone, Copy, Debug)] pub enum LinearizationStrategy { /// Older linearization strategy of GWC19, which has linearization /// polynomial that doesn't evaluate to 0, and requires prover to send extra diff --git a/src/verifier/plonk.rs b/src/verifier/plonk.rs index c3276af4..df208949 100644 --- a/src/verifier/plonk.rs +++ b/src/verifier/plonk.rs @@ -137,8 +137,8 @@ where instances .iter() .zip(bases.iter()) - .map(|(scalar, base)| Msm::::base(base.clone()) * scalar) - .chain(constant.clone().map(|constant| Msm::base(constant))) + .map(|(scalar, base)| Msm::::base(base) * scalar) + .chain(constant.as_ref().map(Msm::base)) .sum::>() .evaluate(None) }) @@ -190,12 +190,13 @@ where .accumulator_indices .iter() .map(|accumulator_indices| { - accumulator_indices - .iter() - .map(|&(i, j)| instances[i][j].clone()) - .collect() + AE::from_repr( + &accumulator_indices + .iter() + .map(|&(i, j)| &instances[i][j]) + .collect_vec(), + ) }) - .map(AE::from_repr) .collect::, _>>()?; Ok(Self { @@ -241,25 +242,20 @@ where .collect() } - fn commitments( - &self, - protocol: &Protocol, + fn commitments<'a>( + &'a self, + protocol: &'a Protocol, common_poly_eval: &CommonPolynomialEvaluation, evaluations: &mut HashMap, ) -> Result>, Error> { let loader = common_poly_eval.zn().loader(); let mut commitments = iter::empty() - .chain( - protocol - .preprocessed - .iter() - .map(|value| Msm::base(value.clone())), - ) + .chain(protocol.preprocessed.iter().map(Msm::base)) .chain( self.committed_instances - .clone() + .as_ref() .map(|committed_instances| { - committed_instances.into_iter().map(Msm::base).collect_vec() + committed_instances.iter().map(Msm::base).collect_vec() }) .unwrap_or_else(|| { iter::repeat_with(Default::default) @@ -267,7 +263,7 @@ where .collect_vec() }), ) - .chain(self.witnesses.iter().cloned().map(Msm::base)) + .chain(self.witnesses.iter().map(Msm::base)) .collect_vec(); let numerator = protocol.quotient.numerator.evaluate( @@ -314,7 +310,7 @@ where .pow_const(protocol.quotient.chunk_degree as u64) .powers(self.quotients.len()) .into_iter() - .zip(self.quotients.iter().cloned().map(Msm::base)) + .zip(self.quotients.iter().map(Msm::base)) .map(|(coeff, chunk)| chunk * &coeff) .sum::>(); match protocol.linearization { From 5c31088f8530470227c411b5d168b445cb59858b Mon Sep 17 00:00:00 2001 From: Han Date: Tue, 8 Nov 2022 19:58:00 -0800 Subject: [PATCH 18/28] feat: implement ipa pcs and accumulation (#14) --- Cargo.toml | 7 +- src/loader/halo2/test.rs | 21 +- src/pcs.rs | 1 + src/pcs/ipa.rs | 447 ++++++++++++++++++++++++++++ src/pcs/ipa/accumulation.rs | 291 ++++++++++++++++++ src/pcs/ipa/accumulator.rs | 21 ++ src/pcs/ipa/decider.rs | 57 ++++ src/pcs/ipa/multiopen.rs | 3 + src/pcs/ipa/multiopen/bgh19.rs | 417 ++++++++++++++++++++++++++ src/system/halo2.rs | 1 + src/system/halo2/strategy.rs | 53 ++++ src/system/halo2/test.rs | 1 + src/system/halo2/test/ipa.rs | 143 +++++++++ src/system/halo2/test/ipa/native.rs | 59 ++++ src/system/halo2/test/kzg/halo2.rs | 9 +- src/util.rs | 40 +++ src/util/arithmetic.rs | 12 + src/util/msm.rs | 123 +++++++- src/util/poly.rs | 175 +++++++++++ 19 files changed, 1866 insertions(+), 15 deletions(-) create mode 100644 src/pcs/ipa.rs create mode 100644 src/pcs/ipa/accumulation.rs create mode 100644 src/pcs/ipa/accumulator.rs create mode 100644 src/pcs/ipa/decider.rs create mode 100644 src/pcs/ipa/multiopen.rs create mode 100644 src/pcs/ipa/multiopen/bgh19.rs create mode 100644 src/system/halo2/strategy.rs create mode 100644 src/system/halo2/test/ipa.rs create mode 100644 src/system/halo2/test/ipa/native.rs create mode 100644 src/util/poly.rs diff --git a/Cargo.toml b/Cargo.toml index e5bbb0d2..44dcd60b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,9 @@ rand = "0.8" hex = "0.4" halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.3.0", package = "halo2curves" } +# parallel +rayon = { version = "1.5.3", optional = true } + # system_halo2 halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2", tag = "v2022_10_22", optional = true } @@ -41,13 +44,13 @@ tui = { version = "0.19", default-features = false, features = ["crossterm"] } [features] default = ["loader_evm", "loader_halo2", "system_halo2"] +parallel = ["dep:rayon"] + loader_evm = ["dep:ethereum_types", "dep:sha3", "dep:revm", "dep:bytes", "dep:rlp"] loader_halo2 = ["dep:halo2_proofs", "dep:halo2_wrong_ecc", "dep:poseidon"] system_halo2 = ["dep:halo2_proofs"] -sanity_check = [] - [[example]] name = "evm-verifier" required-features = ["loader_evm", "system_halo2"] diff --git a/src/loader/halo2/test.rs b/src/loader/halo2/test.rs index 08551fe0..dd2fccaa 100644 --- a/src/loader/halo2/test.rs +++ b/src/loader/halo2/test.rs @@ -4,6 +4,7 @@ use crate::{ }; use halo2_proofs::circuit::Value; +#[derive(Clone, Debug)] pub struct Snark { pub protocol: Protocol, pub instances: Vec>, @@ -27,6 +28,7 @@ impl Snark { } } +#[derive(Clone, Debug)] pub struct SnarkWitness { pub protocol: Protocol, pub instances: Vec>>, @@ -48,18 +50,23 @@ impl From> for SnarkWitness { } impl SnarkWitness { - pub fn without_witnesses(&self) -> Self { + pub fn new_without_witness(protocol: Protocol) -> Self { + let instances = protocol + .num_instance + .iter() + .map(|num_instance| vec![Value::unknown(); *num_instance]) + .collect(); SnarkWitness { - protocol: self.protocol.clone(), - instances: self - .instances - .iter() - .map(|instances| vec![Value::unknown(); instances.len()]) - .collect(), + protocol, + instances, proof: Value::unknown(), } } + pub fn without_witnesses(&self) -> Self { + SnarkWitness::new_without_witness(self.protocol.clone()) + } + pub fn proof(&self) -> Value<&[u8]> { self.proof.as_ref().map(Vec::as_slice) } diff --git a/src/pcs.rs b/src/pcs.rs index 23fdda2a..bf944a43 100644 --- a/src/pcs.rs +++ b/src/pcs.rs @@ -10,6 +10,7 @@ use crate::{ use rand::Rng; use std::fmt::Debug; +pub mod ipa; pub mod kzg; pub trait PolynomialCommitmentScheme: Clone + Debug diff --git a/src/pcs/ipa.rs b/src/pcs/ipa.rs new file mode 100644 index 00000000..a2b34824 --- /dev/null +++ b/src/pcs/ipa.rs @@ -0,0 +1,447 @@ +use crate::{ + loader::{native::NativeLoader, LoadedScalar, Loader, ScalarLoader}, + pcs::PolynomialCommitmentScheme, + util::{ + arithmetic::{ + inner_product, powers, Curve, CurveAffine, Domain, Field, Fraction, PrimeField, + }, + msm::{multi_scalar_multiplication, Msm}, + parallelize, + poly::Polynomial, + transcript::{TranscriptRead, TranscriptWrite}, + Itertools, + }, + Error, +}; +use rand::Rng; +use std::{fmt::Debug, iter, marker::PhantomData}; + +mod accumulation; +mod accumulator; +mod decider; +mod multiopen; + +pub use accumulation::{IpaAs, IpaAsProof}; +pub use accumulator::IpaAccumulator; +pub use decider::IpaDecidingKey; +pub use multiopen::{Bgh19, Bgh19Proof, Bgh19SuccinctVerifyingKey}; + +#[derive(Clone, Debug)] +pub struct Ipa(PhantomData<(C, MOS)>); + +impl PolynomialCommitmentScheme for Ipa +where + C: CurveAffine, + L: Loader, + MOS: Clone + Debug, +{ + type Accumulator = IpaAccumulator; +} + +impl Ipa +where + C: CurveAffine, +{ + pub fn create_proof( + pk: &IpaProvingKey, + p: &[C::Scalar], + z: &C::Scalar, + omega: Option<&C::Scalar>, + transcript: &mut T, + mut rng: R, + ) -> Result, Error> + where + T: TranscriptWrite, + R: Rng, + { + let mut p_prime = Polynomial::new(p.to_vec()); + if pk.zk() { + let p_bar = { + let mut p_bar = Polynomial::rand(p.len(), &mut rng); + let p_bar_at_z = p_bar.evaluate(*z); + p_bar[0] -= p_bar_at_z; + p_bar + }; + let omega_bar = C::Scalar::random(&mut rng); + let c_bar = pk.commit(&p_bar, Some(omega_bar)); + transcript.write_ec_point(c_bar)?; + + let alpha = transcript.squeeze_challenge(); + let omega_prime = *omega.unwrap() + alpha * omega_bar; + transcript.write_scalar(omega_prime)?; + + p_prime = p_prime + &(p_bar * alpha); + }; + + let xi_0 = transcript.squeeze_challenge(); + let h_prime = pk.h * xi_0; + let mut bases = pk.g.clone(); + let mut coeffs = p_prime.to_vec(); + let mut zs = powers(*z).take(coeffs.len()).collect_vec(); + + let k = pk.domain.k; + let mut xi = Vec::with_capacity(k); + for i in 0..k { + let half = 1 << (k - i - 1); + + let l_i = multi_scalar_multiplication(&coeffs[half..], &bases[..half]) + + h_prime * inner_product(&coeffs[half..], &zs[..half]); + let r_i = multi_scalar_multiplication(&coeffs[..half], &bases[half..]) + + h_prime * inner_product(&coeffs[..half], &zs[half..]); + transcript.write_ec_point(l_i.to_affine())?; + transcript.write_ec_point(r_i.to_affine())?; + + let xi_i = transcript.squeeze_challenge(); + let xi_i_inv = Field::invert(&xi_i).unwrap(); + + let (bases_l, bases_r) = bases.split_at_mut(half); + let (coeffs_l, coeffs_r) = coeffs.split_at_mut(half); + let (zs_l, zs_r) = zs.split_at_mut(half); + parallelize(bases_l, |(bases_l, start)| { + let mut tmp = Vec::with_capacity(bases_l.len()); + for (lhs, rhs) in bases_l.iter().zip(bases_r[start..].iter()) { + tmp.push(lhs.to_curve() + *rhs * xi_i); + } + C::Curve::batch_normalize(&tmp, bases_l); + }); + parallelize(coeffs_l, |(coeffs_l, start)| { + for (lhs, rhs) in coeffs_l.iter_mut().zip(coeffs_r[start..].iter()) { + *lhs += xi_i_inv * rhs; + } + }); + parallelize(zs_l, |(zs_l, start)| { + for (lhs, rhs) in zs_l.iter_mut().zip(zs_r[start..].iter()) { + *lhs += xi_i * rhs; + } + }); + bases = bases_l.to_vec(); + coeffs = coeffs_l.to_vec(); + zs = zs_l.to_vec(); + + xi.push(xi_i); + } + + transcript.write_ec_point(bases[0])?; + transcript.write_scalar(coeffs[0])?; + + Ok(IpaAccumulator::new(xi, bases[0])) + } + + pub fn read_proof>( + svk: &IpaSuccinctVerifyingKey, + transcript: &mut T, + ) -> Result, Error> + where + T: TranscriptRead, + { + IpaProof::read(svk, transcript) + } + + pub fn succinct_verify>( + svk: &IpaSuccinctVerifyingKey, + commitment: &Msm, + z: &L::LoadedScalar, + eval: &L::LoadedScalar, + proof: &IpaProof, + ) -> Result, Error> { + let loader = z.loader(); + let h = loader.ec_point_load_const(&svk.h); + let s = svk.s.as_ref().map(|s| loader.ec_point_load_const(s)); + let h = Msm::::base(&h); + + let h_prime = h * &proof.xi_0; + let lhs = { + let c_prime = match ( + s.as_ref(), + proof.c_bar_alpha.as_ref(), + proof.omega_prime.as_ref(), + ) { + (Some(s), Some((c_bar, alpha)), Some(omega_prime)) => { + let s = Msm::::base(s); + commitment.clone() + Msm::base(c_bar) * alpha - s * omega_prime + } + (None, None, None) => commitment.clone(), + _ => unreachable!(), + }; + let c_0 = c_prime + h_prime.clone() * eval; + let c_k = c_0 + + proof + .rounds + .iter() + .zip(proof.xi_inv().iter()) + .flat_map(|(Round { l, r, xi }, xi_inv)| [(l, xi_inv), (r, xi)]) + .map(|(base, scalar)| Msm::::base(base) * scalar) + .sum::>(); + c_k.evaluate(None) + }; + let rhs = { + let u = Msm::::base(&proof.u); + let v_prime = h_eval(&proof.xi(), z) * &proof.c; + (u * &proof.c + h_prime * &v_prime).evaluate(None) + }; + + loader.ec_point_assert_eq("C_k == c[U] + v'[H']", &lhs, &rhs)?; + + Ok(IpaAccumulator::new(proof.xi(), proof.u.clone())) + } +} + +#[derive(Clone, Debug)] +pub struct IpaProvingKey { + pub domain: Domain, + pub g: Vec, + pub h: C, + pub s: Option, +} + +impl IpaProvingKey { + pub fn new(domain: Domain, g: Vec, h: C, s: Option) -> Self { + Self { domain, g, h, s } + } + + pub fn zk(&self) -> bool { + self.s.is_some() + } + + pub fn svk(&self) -> IpaSuccinctVerifyingKey { + IpaSuccinctVerifyingKey::new(self.domain.clone(), self.h, self.s) + } + + pub fn dk(&self) -> IpaDecidingKey { + IpaDecidingKey::new(self.g.clone()) + } + + pub fn commit(&self, poly: &Polynomial, omega: Option) -> C { + let mut c = multi_scalar_multiplication(&poly[..], &self.g); + match (self.s, omega) { + (Some(s), Some(omega)) => c += s * omega, + (None, None) => {} + _ => unreachable!(), + }; + c.to_affine() + } +} + +impl IpaProvingKey { + #[cfg(test)] + pub fn rand(k: usize, zk: bool, mut rng: R) -> Self { + use crate::util::arithmetic::{root_of_unity, Group}; + + let domain = Domain::new(k, root_of_unity(k)); + let mut g = vec![C::default(); 1 << k]; + C::Curve::batch_normalize( + &iter::repeat_with(|| C::Curve::random(&mut rng)) + .take(1 << k) + .collect_vec(), + &mut g, + ); + let h = C::Curve::random(&mut rng).to_affine(); + let s = zk.then(|| C::Curve::random(&mut rng).to_affine()); + Self { domain, g, h, s } + } +} + +#[derive(Clone, Debug)] +pub struct IpaSuccinctVerifyingKey { + pub domain: Domain, + pub h: C, + pub s: Option, +} + +impl IpaSuccinctVerifyingKey { + pub fn new(domain: Domain, h: C, s: Option) -> Self { + Self { domain, h, s } + } + + pub fn zk(&self) -> bool { + self.s.is_some() + } +} + +#[derive(Clone, Debug)] +pub struct IpaProof +where + C: CurveAffine, + L: Loader, +{ + c_bar_alpha: Option<(L::LoadedEcPoint, L::LoadedScalar)>, + omega_prime: Option, + xi_0: L::LoadedScalar, + rounds: Vec>, + u: L::LoadedEcPoint, + c: L::LoadedScalar, +} + +impl IpaProof +where + C: CurveAffine, + L: Loader, +{ + pub fn new( + c_bar_alpha: Option<(L::LoadedEcPoint, L::LoadedScalar)>, + omega_prime: Option, + xi_0: L::LoadedScalar, + rounds: Vec>, + u: L::LoadedEcPoint, + c: L::LoadedScalar, + ) -> Self { + Self { + c_bar_alpha, + omega_prime, + xi_0, + rounds, + u, + c, + } + } + + pub fn read(svk: &IpaSuccinctVerifyingKey, transcript: &mut T) -> Result + where + T: TranscriptRead, + { + let c_bar_alpha = svk + .zk() + .then(|| { + let c_bar = transcript.read_ec_point()?; + let alpha = transcript.squeeze_challenge(); + Ok((c_bar, alpha)) + }) + .transpose()?; + let omega_prime = svk.zk().then(|| transcript.read_scalar()).transpose()?; + let xi_0 = transcript.squeeze_challenge(); + let rounds = iter::repeat_with(|| { + Ok(Round::new( + transcript.read_ec_point()?, + transcript.read_ec_point()?, + transcript.squeeze_challenge(), + )) + }) + .take(svk.domain.k) + .collect::, _>>()?; + let u = transcript.read_ec_point()?; + let c = transcript.read_scalar()?; + Ok(Self { + c_bar_alpha, + omega_prime, + xi_0, + rounds, + u, + c, + }) + } + + pub fn xi(&self) -> Vec { + self.rounds.iter().map(|round| round.xi.clone()).collect() + } + + pub fn xi_inv(&self) -> Vec { + let mut xi_inv = self.xi().into_iter().map(Fraction::one_over).collect_vec(); + L::batch_invert(xi_inv.iter_mut().filter_map(Fraction::denom_mut)); + xi_inv.iter_mut().for_each(Fraction::evaluate); + xi_inv + .into_iter() + .map(|xi_inv| xi_inv.evaluated().clone()) + .collect() + } +} + +#[derive(Clone, Debug)] +pub struct Round +where + C: CurveAffine, + L: Loader, +{ + l: L::LoadedEcPoint, + r: L::LoadedEcPoint, + xi: L::LoadedScalar, +} + +impl Round +where + C: CurveAffine, + L: Loader, +{ + pub fn new(l: L::LoadedEcPoint, r: L::LoadedEcPoint, xi: L::LoadedScalar) -> Self { + Self { l, r, xi } + } +} + +pub fn h_eval>(xi: &[T], z: &T) -> T { + let loader = z.loader(); + let one = loader.load_one(); + loader.product( + &iter::successors(Some(z.clone()), |z| Some(z.square())) + .zip(xi.iter().rev()) + .map(|(z, xi)| z * xi + &one) + .collect_vec() + .iter() + .collect_vec(), + ) +} + +pub fn h_coeffs(xi: &[F], scalar: F) -> Vec { + assert!(!xi.is_empty()); + + let mut coeffs = vec![F::zero(); 1 << xi.len()]; + coeffs[0] = scalar; + + for (len, xi) in xi.iter().rev().enumerate().map(|(i, xi)| (1 << i, xi)) { + let (left, right) = coeffs.split_at_mut(len); + let right = &mut right[0..len]; + right.copy_from_slice(left); + for coeffs in right { + *coeffs *= xi; + } + } + + coeffs +} + +#[cfg(all(test, feature = "system_halo2"))] +mod test { + use crate::{ + pcs::{ + ipa::{self, IpaProvingKey}, + Decider, + }, + util::{arithmetic::Field, msm::Msm, poly::Polynomial}, + }; + use halo2_curves::pasta::pallas; + use halo2_proofs::transcript::{ + Blake2bRead, Blake2bWrite, TranscriptReadBuffer, TranscriptWriterBuffer, + }; + use rand::rngs::OsRng; + + #[test] + fn test_ipa() { + type Ipa = ipa::Ipa; + + let k = 10; + let mut rng = OsRng; + + for zk in [false, true] { + let pk = IpaProvingKey::::rand(k, zk, &mut rng); + let (c, z, v, proof) = { + let p = Polynomial::::rand(pk.domain.n, &mut rng); + let omega = pk.zk().then(|| pallas::Scalar::random(&mut rng)); + let c = pk.commit(&p, omega); + let z = pallas::Scalar::random(&mut rng); + let v = p.evaluate(z); + let mut transcript = Blake2bWrite::init(Vec::new()); + Ipa::create_proof(&pk, &p[..], &z, omega.as_ref(), &mut transcript, &mut rng) + .unwrap(); + (c, z, v, transcript.finalize()) + }; + + let svk = pk.svk(); + let accumulator = { + let mut transcript = Blake2bRead::init(proof.as_slice()); + let proof = Ipa::read_proof(&svk, &mut transcript).unwrap(); + Ipa::succinct_verify(&svk, &Msm::base(&c), &z, &v, &proof).unwrap() + }; + + let dk = pk.dk(); + assert!(Ipa::decide(&dk, accumulator)); + } + } +} diff --git a/src/pcs/ipa/accumulation.rs b/src/pcs/ipa/accumulation.rs new file mode 100644 index 00000000..eeea9efe --- /dev/null +++ b/src/pcs/ipa/accumulation.rs @@ -0,0 +1,291 @@ +use crate::{ + loader::{native::NativeLoader, LoadedScalar, Loader}, + pcs::{ + ipa::{ + h_coeffs, h_eval, Ipa, IpaAccumulator, IpaProof, IpaProvingKey, IpaSuccinctVerifyingKey, + }, + AccumulationScheme, AccumulationSchemeProver, PolynomialCommitmentScheme, + }, + util::{ + arithmetic::{Curve, CurveAffine, Field}, + msm::Msm, + poly::Polynomial, + transcript::{TranscriptRead, TranscriptWrite}, + Itertools, + }, + Error, +}; +use rand::Rng; +use std::{array, iter, marker::PhantomData}; + +#[derive(Clone, Debug)] +pub struct IpaAs(PhantomData); + +impl AccumulationScheme for IpaAs +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme>, +{ + type VerifyingKey = IpaSuccinctVerifyingKey; + type Proof = IpaAsProof; + + fn read_proof( + vk: &Self::VerifyingKey, + instances: &[PCS::Accumulator], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + { + IpaAsProof::read(vk, instances, transcript) + } + + fn verify( + vk: &Self::VerifyingKey, + instances: &[PCS::Accumulator], + proof: &Self::Proof, + ) -> Result { + let loader = proof.z.loader(); + let s = vk.s.as_ref().map(|s| loader.ec_point_load_const(s)); + + let (u, h) = instances + .iter() + .map(|IpaAccumulator { u, xi }| (u.clone(), h_eval(xi, &proof.z))) + .chain( + proof + .a_b_u + .as_ref() + .map(|(a, b, u)| (u.clone(), a.clone() * &proof.z + b)), + ) + .unzip::<_, _, Vec<_>, Vec<_>>(); + let powers_of_alpha = proof.alpha.powers(u.len()); + + let mut c = powers_of_alpha + .iter() + .zip(u.iter()) + .map(|(power_of_alpha, u)| Msm::::base(u) * power_of_alpha) + .sum::>(); + if let Some(omega) = proof.omega.as_ref() { + c += Msm::base(s.as_ref().unwrap()) * omega; + } + let v = loader.sum_products(&powers_of_alpha.iter().zip(h.iter()).collect_vec()); + + Ipa::::succinct_verify(vk, &c, &proof.z, &v, &proof.ipa) + } +} + +#[derive(Clone, Debug)] +pub struct IpaAsProof +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme>, +{ + a_b_u: Option<(L::LoadedScalar, L::LoadedScalar, L::LoadedEcPoint)>, + omega: Option, + alpha: L::LoadedScalar, + z: L::LoadedScalar, + ipa: IpaProof, + _marker: PhantomData, +} + +impl IpaAsProof +where + C: CurveAffine, + L: Loader, + PCS: PolynomialCommitmentScheme>, +{ + fn read( + vk: &IpaSuccinctVerifyingKey, + instances: &[PCS::Accumulator], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + { + assert!(instances.len() > 1); + + let a_b_u = vk + .zk() + .then(|| { + let a = transcript.read_scalar()?; + let b = transcript.read_scalar()?; + let u = transcript.read_ec_point()?; + Ok((a, b, u)) + }) + .transpose()?; + let omega = vk + .zk() + .then(|| { + let omega = transcript.read_scalar()?; + Ok(omega) + }) + .transpose()?; + + for accumulator in instances { + for xi in accumulator.xi.iter() { + transcript.common_scalar(xi)?; + } + transcript.common_ec_point(&accumulator.u)?; + } + + let alpha = transcript.squeeze_challenge(); + let z = transcript.squeeze_challenge(); + + let ipa = IpaProof::read(vk, transcript)?; + + Ok(Self { + a_b_u, + omega, + alpha, + z, + ipa, + _marker: PhantomData, + }) + } +} + +impl AccumulationSchemeProver for IpaAs +where + C: CurveAffine, + PCS: PolynomialCommitmentScheme>, +{ + type ProvingKey = IpaProvingKey; + + fn create_proof( + pk: &Self::ProvingKey, + instances: &[PCS::Accumulator], + transcript: &mut T, + mut rng: R, + ) -> Result + where + T: TranscriptWrite, + R: Rng, + { + assert!(instances.len() > 1); + + let a_b_u = pk + .zk() + .then(|| { + let [a, b] = array::from_fn(|_| C::Scalar::random(&mut rng)); + let u = (pk.g[1] * a + pk.g[0] * b).to_affine(); + transcript.write_scalar(a)?; + transcript.write_scalar(b)?; + transcript.write_ec_point(u)?; + Ok((a, b, u)) + }) + .transpose()?; + let omega = pk + .zk() + .then(|| { + let omega = C::Scalar::random(&mut rng); + transcript.write_scalar(omega)?; + Ok(omega) + }) + .transpose()?; + + for accumulator in instances { + for xi in accumulator.xi.iter() { + transcript.common_scalar(xi)?; + } + transcript.common_ec_point(&accumulator.u)?; + } + + let alpha = transcript.squeeze_challenge(); + let z = transcript.squeeze_challenge(); + + let (u, h) = instances + .iter() + .map(|IpaAccumulator { u, xi }| (*u, h_coeffs(xi, C::Scalar::one()))) + .chain(a_b_u.map(|(a, b, u)| { + ( + u, + iter::empty() + .chain([b, a]) + .chain(iter::repeat_with(C::Scalar::zero).take(pk.domain.n - 2)) + .collect(), + ) + })) + .unzip::<_, _, Vec<_>, Vec<_>>(); + let powers_of_alpha = alpha.powers(u.len()); + + let h = powers_of_alpha + .into_iter() + .zip(h.into_iter().map(Polynomial::new)) + .map(|(power_of_alpha, h)| h * power_of_alpha) + .sum::>(); + + Ipa::::create_proof(pk, &h.to_vec(), &z, omega.as_ref(), transcript, &mut rng) + } +} + +#[cfg(test)] +mod test { + use crate::{ + pcs::{ + ipa::{self, IpaProvingKey}, + AccumulationScheme, AccumulationSchemeProver, Decider, + }, + util::{arithmetic::Field, msm::Msm, poly::Polynomial, Itertools}, + }; + use halo2_curves::pasta::pallas; + use halo2_proofs::transcript::{ + Blake2bRead, Blake2bWrite, TranscriptReadBuffer, TranscriptWriterBuffer, + }; + use rand::rngs::OsRng; + use std::iter; + + #[test] + fn test_ipa_as() { + type Ipa = ipa::Ipa; + type IpaAs = ipa::IpaAs; + + let k = 10; + let zk = true; + let mut rng = OsRng; + + let pk = IpaProvingKey::::rand(k, zk, &mut rng); + let accumulators = iter::repeat_with(|| { + let (c, z, v, proof) = { + let p = Polynomial::::rand(pk.domain.n, &mut rng); + let omega = pk.zk().then(|| pallas::Scalar::random(&mut rng)); + let c = pk.commit(&p, omega); + let z = pallas::Scalar::random(&mut rng); + let v = p.evaluate(z); + let mut transcript = Blake2bWrite::init(Vec::new()); + Ipa::create_proof(&pk, &p[..], &z, omega.as_ref(), &mut transcript, &mut rng) + .unwrap(); + (c, z, v, transcript.finalize()) + }; + + let svk = pk.svk(); + let accumulator = { + let mut transcript = Blake2bRead::init(proof.as_slice()); + let proof = Ipa::read_proof(&svk, &mut transcript).unwrap(); + Ipa::succinct_verify(&svk, &Msm::base(&c), &z, &v, &proof).unwrap() + }; + + accumulator + }) + .take(10) + .collect_vec(); + + let proof = { + let apk = pk.clone(); + let mut transcript = Blake2bWrite::init(Vec::new()); + IpaAs::create_proof(&apk, &accumulators, &mut transcript, &mut rng).unwrap(); + transcript.finalize() + }; + + let accumulator = { + let avk = pk.svk(); + let mut transcript = Blake2bRead::init(proof.as_slice()); + let proof = IpaAs::read_proof(&avk, &accumulators, &mut transcript).unwrap(); + IpaAs::verify(&avk, &accumulators, &proof).unwrap() + }; + + let dk = pk.dk(); + assert!(Ipa::decide(&dk, accumulator)); + } +} diff --git a/src/pcs/ipa/accumulator.rs b/src/pcs/ipa/accumulator.rs new file mode 100644 index 00000000..27d9d5c7 --- /dev/null +++ b/src/pcs/ipa/accumulator.rs @@ -0,0 +1,21 @@ +use crate::{loader::Loader, util::arithmetic::CurveAffine}; + +#[derive(Clone, Debug)] +pub struct IpaAccumulator +where + C: CurveAffine, + L: Loader, +{ + pub xi: Vec, + pub u: L::LoadedEcPoint, +} + +impl IpaAccumulator +where + C: CurveAffine, + L: Loader, +{ + pub fn new(xi: Vec, u: L::LoadedEcPoint) -> Self { + Self { xi, u } + } +} diff --git a/src/pcs/ipa/decider.rs b/src/pcs/ipa/decider.rs new file mode 100644 index 00000000..2cf8c6cc --- /dev/null +++ b/src/pcs/ipa/decider.rs @@ -0,0 +1,57 @@ +#[derive(Clone, Debug)] +pub struct IpaDecidingKey { + pub g: Vec, +} + +impl IpaDecidingKey { + pub fn new(g: Vec) -> Self { + Self { g } + } +} + +impl From> for IpaDecidingKey { + fn from(g: Vec) -> IpaDecidingKey { + IpaDecidingKey::new(g) + } +} + +mod native { + use crate::{ + loader::native::NativeLoader, + pcs::{ + ipa::{h_coeffs, Ipa, IpaAccumulator, IpaDecidingKey}, + Decider, + }, + util::{ + arithmetic::{Curve, CurveAffine, Field}, + msm::multi_scalar_multiplication, + }, + }; + use std::fmt::Debug; + + impl Decider for Ipa + where + C: CurveAffine, + MOS: Clone + Debug, + { + type DecidingKey = IpaDecidingKey; + type Output = bool; + + fn decide( + dk: &Self::DecidingKey, + IpaAccumulator { u, xi }: IpaAccumulator, + ) -> bool { + let h = h_coeffs(&xi, C::Scalar::one()); + u == multi_scalar_multiplication(&h, &dk.g).to_affine() + } + + fn decide_all( + dk: &Self::DecidingKey, + accumulators: Vec>, + ) -> bool { + !accumulators + .into_iter() + .any(|accumulator| !Self::decide(dk, accumulator)) + } + } +} diff --git a/src/pcs/ipa/multiopen.rs b/src/pcs/ipa/multiopen.rs new file mode 100644 index 00000000..9f685e76 --- /dev/null +++ b/src/pcs/ipa/multiopen.rs @@ -0,0 +1,3 @@ +mod bgh19; + +pub use bgh19::{Bgh19, Bgh19Proof, Bgh19SuccinctVerifyingKey}; diff --git a/src/pcs/ipa/multiopen/bgh19.rs b/src/pcs/ipa/multiopen/bgh19.rs new file mode 100644 index 00000000..29d291ad --- /dev/null +++ b/src/pcs/ipa/multiopen/bgh19.rs @@ -0,0 +1,417 @@ +use crate::{ + loader::{LoadedScalar, Loader, ScalarLoader}, + pcs::{ + ipa::{Ipa, IpaProof, IpaSuccinctVerifyingKey, Round}, + MultiOpenScheme, Query, + }, + util::{ + arithmetic::{ilog2, CurveAffine, Domain, FieldExt, Fraction}, + msm::Msm, + transcript::TranscriptRead, + Itertools, + }, + Error, +}; +use std::{ + collections::{BTreeMap, BTreeSet}, + iter, + marker::PhantomData, +}; + +#[derive(Clone, Debug)] +pub struct Bgh19; + +impl MultiOpenScheme for Ipa +where + C: CurveAffine, + L: Loader, +{ + type SuccinctVerifyingKey = Bgh19SuccinctVerifyingKey; + type Proof = Bgh19Proof; + + fn read_proof( + svk: &Self::SuccinctVerifyingKey, + queries: &[Query], + transcript: &mut T, + ) -> Result + where + T: TranscriptRead, + { + Bgh19Proof::read(svk, queries, transcript) + } + + fn succinct_verify( + svk: &Self::SuccinctVerifyingKey, + commitments: &[Msm], + x: &L::LoadedScalar, + queries: &[Query], + proof: &Self::Proof, + ) -> Result { + let loader = x.loader(); + let g = loader.ec_point_load_const(&svk.g); + + // Multiopen + let sets = query_sets(queries); + let p = { + let coeffs = query_set_coeffs(&sets, x, &proof.x_3); + + let powers_of_x_1 = proof + .x_1 + .powers(sets.iter().map(|set| set.polys.len()).max().unwrap()); + let f_eval = { + let powers_of_x_2 = proof.x_2.powers(sets.len()); + let f_evals = sets + .iter() + .zip(coeffs.iter()) + .zip(proof.q_evals.iter()) + .map(|((set, coeff), q_eval)| set.f_eval(coeff, q_eval, &powers_of_x_1)) + .collect_vec(); + x.loader() + .sum_products(&powers_of_x_2.iter().zip(f_evals.iter().rev()).collect_vec()) + }; + let msms = sets + .iter() + .zip(proof.q_evals.iter()) + .map(|(set, q_eval)| set.msm(commitments, q_eval, &powers_of_x_1)); + + let (mut msm, constant) = iter::once(Msm::base(&proof.f) - Msm::constant(f_eval)) + .chain(msms) + .zip(proof.x_4.powers(sets.len() + 1).into_iter().rev()) + .map(|(msm, power_of_x_4)| msm * &power_of_x_4) + .sum::>() + .split(); + if let Some(constant) = constant { + msm += Msm::base(&g) * &constant; + } + msm + }; + + // IPA + Ipa::::succinct_verify(&svk.ipa, &p, &proof.x_3, &loader.load_zero(), &proof.ipa) + } +} + +#[derive(Clone, Debug)] +pub struct Bgh19SuccinctVerifyingKey { + g: C, + ipa: IpaSuccinctVerifyingKey, +} + +impl Bgh19SuccinctVerifyingKey { + pub fn new(domain: Domain, g: C, w: C, u: C) -> Self { + Self { + g, + ipa: IpaSuccinctVerifyingKey::new(domain, u, Some(w)), + } + } +} + +#[derive(Clone, Debug)] +pub struct Bgh19Proof +where + C: CurveAffine, + L: Loader, +{ + // Multiopen + x_1: L::LoadedScalar, + x_2: L::LoadedScalar, + f: L::LoadedEcPoint, + x_3: L::LoadedScalar, + q_evals: Vec, + x_4: L::LoadedScalar, + // IPA + ipa: IpaProof, +} + +impl Bgh19Proof +where + C: CurveAffine, + L: Loader, +{ + fn read>( + svk: &Bgh19SuccinctVerifyingKey, + queries: &[Query], + transcript: &mut T, + ) -> Result { + // Multiopen + let x_1 = transcript.squeeze_challenge(); + let x_2 = transcript.squeeze_challenge(); + let f = transcript.read_ec_point()?; + let x_3 = transcript.squeeze_challenge(); + let q_evals = transcript.read_n_scalars(query_sets(queries).len())?; + let x_4 = transcript.squeeze_challenge(); + // IPA + let s = transcript.read_ec_point()?; + let xi = transcript.squeeze_challenge(); + let z = transcript.squeeze_challenge(); + let rounds = iter::repeat_with(|| { + Ok(Round::new( + transcript.read_ec_point()?, + transcript.read_ec_point()?, + transcript.squeeze_challenge(), + )) + }) + .take(svk.ipa.domain.k) + .collect::, _>>()?; + let c = transcript.read_scalar()?; + let blind = transcript.read_scalar()?; + let g = transcript.read_ec_point()?; + Ok(Bgh19Proof { + x_1, + x_2, + f, + x_3, + q_evals, + x_4, + ipa: IpaProof::new(Some((s, xi)), Some(blind), z, rounds, g, c), + }) + } +} + +fn query_sets(queries: &[Query]) -> Vec> +where + F: FieldExt, + T: Clone, +{ + let poly_shifts = queries.iter().fold( + Vec::<(usize, Vec, Vec<&T>)>::new(), + |mut poly_shifts, query| { + if let Some(pos) = poly_shifts + .iter() + .position(|(poly, _, _)| *poly == query.poly) + { + let (_, shifts, evals) = &mut poly_shifts[pos]; + if !shifts.contains(&query.shift) { + shifts.push(query.shift); + evals.push(&query.eval); + } + } else { + poly_shifts.push((query.poly, vec![query.shift], vec![&query.eval])); + } + poly_shifts + }, + ); + + poly_shifts.into_iter().fold( + Vec::>::new(), + |mut sets, (poly, shifts, evals)| { + if let Some(pos) = sets.iter().position(|set| { + BTreeSet::from_iter(set.shifts.iter()) == BTreeSet::from_iter(shifts.iter()) + }) { + let set = &mut sets[pos]; + if !set.polys.contains(&poly) { + set.polys.push(poly); + set.evals.push( + set.shifts + .iter() + .map(|lhs| { + let idx = shifts.iter().position(|rhs| lhs == rhs).unwrap(); + evals[idx] + }) + .collect(), + ); + } + } else { + let set = QuerySet { + shifts, + polys: vec![poly], + evals: vec![evals], + }; + sets.push(set); + } + sets + }, + ) +} + +fn query_set_coeffs(sets: &[QuerySet], x: &T, x_3: &T) -> Vec> +where + F: FieldExt, + T: LoadedScalar, +{ + let loader = x.loader(); + let superset = sets + .iter() + .flat_map(|set| set.shifts.clone()) + .sorted() + .dedup(); + + let size = 2.max( + ilog2((sets.iter().map(|set| set.shifts.len()).max().unwrap() - 1).next_power_of_two()) + 1, + ); + let powers_of_x = x.powers(size); + let x_3_minus_x_shift_i = BTreeMap::from_iter( + superset.map(|shift| (shift, x_3.clone() - x.clone() * loader.load_const(&shift))), + ); + + let mut coeffs = sets + .iter() + .map(|set| QuerySetCoeff::new(&set.shifts, &powers_of_x, x_3, &x_3_minus_x_shift_i)) + .collect_vec(); + + T::Loader::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); + T::Loader::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); + coeffs.iter_mut().for_each(QuerySetCoeff::evaluate); + + coeffs +} + +#[derive(Clone, Debug)] +struct QuerySet<'a, F, T> { + shifts: Vec, + polys: Vec, + evals: Vec>, +} + +impl<'a, F, T> QuerySet<'a, F, T> +where + F: FieldExt, + T: LoadedScalar, +{ + fn msm>( + &self, + commitments: &[Msm<'a, C, L>], + q_eval: &T, + powers_of_x_1: &[T], + ) -> Msm { + self.polys + .iter() + .rev() + .zip(powers_of_x_1) + .map(|(poly, power_of_x_1)| commitments[*poly].clone() * power_of_x_1) + .sum::>() + - Msm::constant(q_eval.clone()) + } + + fn f_eval(&self, coeff: &QuerySetCoeff, q_eval: &T, powers_of_x_1: &[T]) -> T { + let loader = q_eval.loader(); + let r_eval = { + let r_evals = self + .evals + .iter() + .map(|evals| { + loader.sum_products( + &coeff + .eval_coeffs + .iter() + .zip(evals.iter()) + .map(|(coeff, eval)| (coeff.evaluated(), *eval)) + .collect_vec(), + ) * coeff.r_eval_coeff.as_ref().unwrap().evaluated() + }) + .collect_vec(); + loader.sum_products(&r_evals.iter().rev().zip(powers_of_x_1).collect_vec()) + }; + + (q_eval.clone() - r_eval) * coeff.f_eval_coeff.evaluated() + } +} + +#[derive(Clone, Debug)] +struct QuerySetCoeff { + eval_coeffs: Vec>, + r_eval_coeff: Option>, + f_eval_coeff: Fraction, + _marker: PhantomData, +} + +impl QuerySetCoeff +where + F: FieldExt, + T: LoadedScalar, +{ + fn new(shifts: &[F], powers_of_x: &[T], x_3: &T, x_3_minus_x_shift_i: &BTreeMap) -> Self { + let loader = x_3.loader(); + let normalized_ell_primes = shifts + .iter() + .enumerate() + .map(|(j, shift_j)| { + shifts + .iter() + .enumerate() + .filter(|&(i, _)| i != j) + .map(|(_, shift_i)| (*shift_j - shift_i)) + .reduce(|acc, value| acc * value) + .unwrap_or_else(|| F::one()) + }) + .collect_vec(); + + let x = &powers_of_x[1].clone(); + let x_pow_k_minus_one = { + let k_minus_one = shifts.len() - 1; + powers_of_x + .iter() + .enumerate() + .skip(1) + .filter_map(|(i, power_of_x)| { + (k_minus_one & (1 << i) == 1).then(|| power_of_x.clone()) + }) + .reduce(|acc, value| acc * value) + .unwrap_or_else(|| loader.load_one()) + }; + + let barycentric_weights = shifts + .iter() + .zip(normalized_ell_primes.iter()) + .map(|(shift, normalized_ell_prime)| { + loader.sum_products_with_coeff(&[ + (*normalized_ell_prime, &x_pow_k_minus_one, x_3), + (-(*normalized_ell_prime * shift), &x_pow_k_minus_one, x), + ]) + }) + .map(Fraction::one_over) + .collect_vec(); + + let f_eval_coeff = Fraction::one_over( + loader.product( + &shifts + .iter() + .map(|shift| x_3_minus_x_shift_i.get(shift).unwrap()) + .collect_vec(), + ), + ); + + Self { + eval_coeffs: barycentric_weights, + r_eval_coeff: None, + f_eval_coeff, + _marker: PhantomData, + } + } + + fn denoms(&mut self) -> impl IntoIterator { + if self.eval_coeffs.first().unwrap().denom().is_some() { + return self + .eval_coeffs + .iter_mut() + .chain(Some(&mut self.f_eval_coeff)) + .filter_map(Fraction::denom_mut) + .collect_vec(); + } + + if self.r_eval_coeff.is_none() { + self.eval_coeffs + .iter_mut() + .chain(Some(&mut self.f_eval_coeff)) + .for_each(Fraction::evaluate); + + let loader = self.f_eval_coeff.evaluated().loader(); + let barycentric_weights_sum = loader.sum( + &self + .eval_coeffs + .iter() + .map(Fraction::evaluated) + .collect_vec(), + ); + self.r_eval_coeff = Some(Fraction::one_over(barycentric_weights_sum)); + + return vec![self.r_eval_coeff.as_mut().unwrap().denom_mut().unwrap()]; + } + + unreachable!() + } + + fn evaluate(&mut self) { + self.r_eval_coeff.as_mut().unwrap().evaluate(); + } +} diff --git a/src/system/halo2.rs b/src/system/halo2.rs index a49fa6ef..7c26387d 100644 --- a/src/system/halo2.rs +++ b/src/system/halo2.rs @@ -16,6 +16,7 @@ use halo2_proofs::{ use num_integer::Integer; use std::{io, iter, mem::size_of}; +pub mod strategy; pub mod transcript; #[cfg(test)] diff --git a/src/system/halo2/strategy.rs b/src/system/halo2/strategy.rs new file mode 100644 index 00000000..de66f8e3 --- /dev/null +++ b/src/system/halo2/strategy.rs @@ -0,0 +1,53 @@ +pub mod ipa { + use crate::util::arithmetic::CurveAffine; + use halo2_proofs::{ + plonk::Error, + poly::{ + commitment::MSM, + ipa::{ + commitment::{IPACommitmentScheme, ParamsIPA}, + msm::MSMIPA, + multiopen::VerifierIPA, + strategy::GuardIPA, + }, + VerificationStrategy, + }, + }; + + #[derive(Clone, Debug)] + pub struct SingleStrategy<'a, C: CurveAffine> { + msm: MSMIPA<'a, C>, + } + + impl<'a, C: CurveAffine> VerificationStrategy<'a, IPACommitmentScheme, VerifierIPA<'a, C>> + for SingleStrategy<'a, C> + { + type Output = C; + + fn new(params: &'a ParamsIPA) -> Self { + SingleStrategy { + msm: MSMIPA::new(params), + } + } + + fn process( + self, + f: impl FnOnce(MSMIPA<'a, C>) -> Result, Error>, + ) -> Result { + let guard = f(self.msm)?; + + let g = guard.compute_g(); + let (msm, _) = guard.use_g(g); + + if msm.check() { + Ok(g) + } else { + Err(Error::ConstraintSystemFailure) + } + } + + fn finalize(self) -> bool { + unreachable!() + } + } +} diff --git a/src/system/halo2/test.rs b/src/system/halo2/test.rs index 9cd4a2fc..1ec03306 100644 --- a/src/system/halo2/test.rs +++ b/src/system/halo2/test.rs @@ -12,6 +12,7 @@ use rand_chacha::rand_core::RngCore; use std::{fs, io::Cursor}; mod circuit; +mod ipa; mod kzg; pub use circuit::{ diff --git a/src/system/halo2/test/ipa.rs b/src/system/halo2/test/ipa.rs new file mode 100644 index 00000000..07fd6efd --- /dev/null +++ b/src/system/halo2/test/ipa.rs @@ -0,0 +1,143 @@ +use crate::util::arithmetic::CurveAffine; +use halo2_proofs::poly::{ + commitment::{Params, ParamsProver}, + ipa::commitment::ParamsIPA, +}; +use std::mem::size_of; + +mod native; + +pub const TESTDATA_DIR: &str = "./src/system/halo2/test/ipa/testdata"; + +pub fn setup(k: u32) -> ParamsIPA { + ParamsIPA::new(k) +} + +pub fn w_u() -> (C, C) { + let mut buf = Vec::new(); + setup::(1).write(&mut buf).unwrap(); + + let repr = C::Repr::default(); + let repr_len = repr.as_ref().len(); + let offset = size_of::() + 4 * repr_len; + + let [w, u] = [offset, offset + repr_len].map(|offset| { + let mut repr = C::Repr::default(); + repr.as_mut() + .copy_from_slice(&buf[offset..offset + repr_len]); + C::from_bytes(&repr).unwrap() + }); + + (w, u) +} + +macro_rules! halo2_ipa_config { + ($zk:expr, $num_proof:expr) => { + $crate::system::halo2::Config::ipa() + .set_zk($zk) + .with_num_proof($num_proof) + }; + ($zk:expr, $num_proof:expr, $accumulator_indices:expr) => { + $crate::system::halo2::Config::ipa() + .set_zk($zk) + .with_num_proof($num_proof) + .with_accumulator_indices($accumulator_indices) + }; +} + +macro_rules! halo2_ipa_prepare { + ($dir:expr, $curve:path, $k:expr, $config:expr, $create_circuit:expr) => {{ + use $crate::system::halo2::test::{halo2_prepare, ipa::setup}; + + halo2_prepare!($dir, $k, setup::<$curve>, $config, $create_circuit) + }}; + (pallas::Affine, $k:expr, $config:expr, $create_circuit:expr) => {{ + use halo2_curves::pasta::pallas; + use $crate::system::halo2::test::ipa::TESTDATA_DIR; + + halo2_ipa_prepare!( + &format!("{TESTDATA_DIR}/pallas"), + pallas::Affine, + $k, + $config, + $create_circuit + ) + }}; + (vesta::Affine, $k:expr, $config:expr, $create_circuit:expr) => {{ + use halo2_curves::pasta::vesta; + use $crate::system::halo2::test::ipa::TESTDATA_DIR; + + halo2_ipa_prepare!( + &format!("{TESTDATA_DIR}/vesta"), + vesta::Affine, + $k, + $config, + $create_circuit + ) + }}; +} + +macro_rules! halo2_ipa_create_snark { + ( + $prover:ty, + $verifier:ty, + $transcript_read:ty, + $transcript_write:ty, + $encoded_challenge:ty, + $params:expr, + $pk:expr, + $protocol:expr, + $circuits:expr + ) => {{ + use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme; + use $crate::{ + system::halo2::{strategy::ipa::SingleStrategy, test::halo2_create_snark}, + util::arithmetic::GroupEncoding, + }; + + halo2_create_snark!( + IPACommitmentScheme<_>, + $prover, + $verifier, + SingleStrategy<_>, + $transcript_read, + $transcript_write, + $encoded_challenge, + |proof, g| { [proof, g.to_bytes().as_ref().to_vec()].concat() }, + $params, + $pk, + $protocol, + $circuits + ) + }}; +} + +macro_rules! halo2_ipa_native_verify { + ( + $plonk_verifier:ty, + $params:expr, + $protocol:expr, + $instances:expr, + $transcript:expr + ) => {{ + use $crate::{ + pcs::ipa::{Bgh19SuccinctVerifyingKey, IpaDecidingKey}, + system::halo2::test::{halo2_native_verify, ipa::w_u}, + }; + + let (w, u) = w_u(); + halo2_native_verify!( + $plonk_verifier, + $params, + $protocol, + $instances, + $transcript, + &Bgh19SuccinctVerifyingKey::new($protocol.domain.clone(), $params.get_g()[0], w, u), + &IpaDecidingKey::new($params.get_g().to_vec()) + ) + }}; +} + +pub(crate) use { + halo2_ipa_config, halo2_ipa_create_snark, halo2_ipa_native_verify, halo2_ipa_prepare, +}; diff --git a/src/system/halo2/test/ipa/native.rs b/src/system/halo2/test/ipa/native.rs new file mode 100644 index 00000000..7d9e09bb --- /dev/null +++ b/src/system/halo2/test/ipa/native.rs @@ -0,0 +1,59 @@ +use crate::{ + pcs::ipa::{Bgh19, Ipa}, + system::halo2::test::ipa::{ + halo2_ipa_config, halo2_ipa_create_snark, halo2_ipa_native_verify, halo2_ipa_prepare, + }, + system::halo2::test::StandardPlonk, + verifier::Plonk, +}; +use halo2_curves::pasta::pallas; +use halo2_proofs::{ + poly::ipa::multiopen::{ProverIPA, VerifierIPA}, + transcript::{Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer}, +}; +use paste::paste; +use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; + +macro_rules! test { + (@ $name:ident, $k:expr, $config:expr, $create_cirucit:expr, $prover:ty, $verifier:ty, $plonk_verifier:ty) => { + paste! { + #[test] + fn []() { + let (params, pk, protocol, circuits) = halo2_ipa_prepare!( + pallas::Affine, + $k, + $config, + $create_cirucit + ); + let snark = halo2_ipa_create_snark!( + $prover, + $verifier, + Blake2bWrite<_, _, _>, + Blake2bRead<_, _, _>, + Challenge255<_>, + ¶ms, + &pk, + &protocol, + &circuits + ); + halo2_ipa_native_verify!( + $plonk_verifier, + params, + &snark.protocol, + &snark.instances, + &mut Blake2bRead::<_, pallas::Affine, _>::init(snark.proof.as_slice()) + ); + } + } + }; + ($name:ident, $k:expr, $config:expr, $create_cirucit:expr) => { + test!(@ $name, $k, $config, $create_cirucit, ProverIPA, VerifierIPA, Plonk::>); + } +} + +test!( + zk_standard_plonk_rand, + 9, + halo2_ipa_config!(true, 1), + StandardPlonk::rand(ChaCha20Rng::from_seed(Default::default())) +); diff --git a/src/system/halo2/test/kzg/halo2.rs b/src/system/halo2/test/kzg/halo2.rs index 11a6046f..314af5c7 100644 --- a/src/system/halo2/test/kzg/halo2.rs +++ b/src/system/halo2/test/kzg/halo2.rs @@ -1,6 +1,6 @@ use crate::{ - loader, loader::{ + self, halo2::test::{Snark, SnarkWitness}, native::NativeLoader, }, @@ -30,8 +30,7 @@ use crate::{ use halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; use halo2_proofs::{ circuit::{floor_planner::V1, Layouter, Value}, - plonk, - plonk::{Circuit, Error}, + plonk::{Circuit, ConstraintSystem, Error}, poly::{ commitment::ParamsProver, kzg::{ @@ -263,7 +262,7 @@ impl Circuit for Accumulation { } } - fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { + fn configure(meta: &mut ConstraintSystem) -> Self::Config { MainGateWithRangeConfig::configure( meta, vec![BITS / LIMBS], @@ -275,7 +274,7 @@ impl Circuit for Accumulation { &self, config: Self::Config, mut layouter: impl Layouter, - ) -> Result<(), plonk::Error> { + ) -> Result<(), Error> { let main_gate = config.main_gate(); let range_chip = config.range_chip(); diff --git a/src/util.rs b/src/util.rs index 3d5d0d79..b42db61c 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,7 +1,47 @@ pub mod arithmetic; pub mod hash; pub mod msm; +pub mod poly; pub mod protocol; pub mod transcript; pub(crate) use itertools::Itertools; + +#[cfg(feature = "parallel")] +pub(crate) use rayon::current_num_threads; + +pub fn parallelize_iter(iter: I, f: F) +where + I: Send + Iterator, + T: Send, + F: Fn(T) + Send + Sync + Clone, +{ + #[cfg(feature = "parallel")] + rayon::scope(|scope| { + for item in iter { + let f = f.clone(); + scope.spawn(move |_| f(item)); + } + }); + #[cfg(not(feature = "parallel"))] + iter.for_each(f); +} + +pub fn parallelize(v: &mut [T], f: F) +where + T: Send, + F: Fn((&mut [T], usize)) + Send + Sync + Clone, +{ + #[cfg(feature = "parallel")] + { + let num_threads = current_num_threads(); + let chunk_size = v.len() / num_threads; + if chunk_size < num_threads { + f((v, 0)); + } else { + parallelize_iter(v.chunks_mut(chunk_size).zip((0..).step_by(chunk_size)), f); + } + } + #[cfg(not(feature = "parallel"))] + f((v, 0)); +} diff --git a/src/util/arithmetic.rs b/src/util/arithmetic.rs index 02d6da32..dacd2443 100644 --- a/src/util/arithmetic.rs +++ b/src/util/arithmetic.rs @@ -250,3 +250,15 @@ pub fn fe_to_limbs(scalar: F) -> impl Iterator { + iter::successors(Some(F::one()), move |power| Some(scalar * power)) +} + +pub fn inner_product(lhs: &[F], rhs: &[F]) -> F { + lhs.iter() + .zip_eq(rhs.iter()) + .map(|(lhs, rhs)| *lhs * rhs) + .reduce(|acc, product| acc + product) + .unwrap_or_default() +} diff --git a/src/util/msm.rs b/src/util/msm.rs index f35ea197..014a29e8 100644 --- a/src/util/msm.rs +++ b/src/util/msm.rs @@ -1,10 +1,15 @@ use crate::{ loader::{LoadedEcPoint, Loader}, - util::{arithmetic::CurveAffine, Itertools}, + util::{ + arithmetic::{CurveAffine, Group, PrimeField}, + Itertools, + }, }; +use num_integer::Integer; use std::{ default::Default, iter::{self, Sum}, + mem::size_of, ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, }; @@ -209,3 +214,119 @@ where iter.reduce(|acc, item| acc + item).unwrap_or_default() } } + +#[derive(Clone, Copy)] +enum Bucket { + None, + Affine(C), + Projective(C::Curve), +} + +impl Bucket { + fn add_assign(&mut self, rhs: &C) { + *self = match *self { + Bucket::None => Bucket::Affine(*rhs), + Bucket::Affine(lhs) => Bucket::Projective(lhs + *rhs), + Bucket::Projective(mut lhs) => { + lhs += *rhs; + Bucket::Projective(lhs) + } + } + } + + fn add(self, mut rhs: C::Curve) -> C::Curve { + match self { + Bucket::None => rhs, + Bucket::Affine(lhs) => { + rhs += lhs; + rhs + } + Bucket::Projective(lhs) => lhs + rhs, + } + } +} + +fn multi_scalar_multiplication_serial( + scalars: &[C::Scalar], + bases: &[C], + result: &mut C::Curve, +) { + let scalars = scalars.iter().map(|scalar| scalar.to_repr()).collect_vec(); + let num_bytes = scalars[0].as_ref().len(); + let num_bits = 8 * num_bytes; + + let window_size = (scalars.len() as f64).ln().ceil() as usize + 2; + let num_buckets = (1 << window_size) - 1; + + let windowed_scalar = |idx: usize, bytes: &::Repr| { + let skip_bits = idx * window_size; + let skip_bytes = skip_bits / 8; + + let mut value = [0; size_of::()]; + for (dst, src) in value.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) { + *dst = *src; + } + + (usize::from_le_bytes(value) >> (skip_bits - (skip_bytes * 8))) & num_buckets + }; + + let num_window = Integer::div_ceil(&num_bits, &window_size); + for idx in (0..num_window).rev() { + for _ in 0..window_size { + *result = result.double(); + } + + let mut buckets = vec![Bucket::None; num_buckets]; + + for (scalar, base) in scalars.iter().zip(bases.iter()) { + let scalar = windowed_scalar(idx, scalar); + if scalar != 0 { + buckets[scalar - 1].add_assign(base); + } + } + + let mut running_sum = C::Curve::identity(); + for bucket in buckets.into_iter().rev() { + running_sum = bucket.add(running_sum); + *result += &running_sum; + } + } +} + +// Copy from https://github.com/zcash/halo2/blob/main/halo2_proofs/src/arithmetic.rs +pub fn multi_scalar_multiplication(scalars: &[C::Scalar], bases: &[C]) -> C::Curve { + assert_eq!(scalars.len(), bases.len()); + + #[cfg(feature = "parallel")] + { + use crate::util::{current_num_threads, parallelize_iter}; + + let num_threads = current_num_threads(); + if scalars.len() < num_threads { + let mut result = C::Curve::identity(); + multi_scalar_multiplication_serial(scalars, bases, &mut result); + return result; + } + + let chunk_size = Integer::div_ceil(&scalars.len(), &num_threads); + let mut results = vec![C::Curve::identity(); num_threads]; + parallelize_iter( + scalars + .chunks(chunk_size) + .zip(bases.chunks(chunk_size)) + .zip(results.iter_mut()), + |((scalars, bases), result)| { + multi_scalar_multiplication_serial(scalars, bases, result); + }, + ); + results + .iter() + .fold(C::Curve::identity(), |acc, result| acc + result) + } + #[cfg(not(feature = "parallel"))] + { + let mut result = C::Curve::identity(); + multi_scalar_multiplication_serial(scalars, bases, &mut result); + result + } +} diff --git a/src/util/poly.rs b/src/util/poly.rs new file mode 100644 index 00000000..ea120b33 --- /dev/null +++ b/src/util/poly.rs @@ -0,0 +1,175 @@ +use crate::util::{arithmetic::Field, parallelize}; +use rand::Rng; +use std::{ + iter::{self, Sum}, + ops::{ + Add, Index, IndexMut, Mul, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, + RangeToInclusive, Sub, + }, +}; + +#[derive(Clone, Debug)] +pub struct Polynomial(Vec); + +impl Polynomial { + pub fn new(inner: Vec) -> Self { + Self(inner) + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn iter(&self) -> impl Iterator { + self.0.iter() + } + + pub fn iter_mut(&mut self) -> impl Iterator { + self.0.iter_mut() + } + + pub fn to_vec(self) -> Vec { + self.0 + } +} + +impl Polynomial { + pub fn rand(n: usize, mut rng: R) -> Self { + Self::new(iter::repeat_with(|| F::random(&mut rng)).take(n).collect()) + } + + pub fn evaluate(&self, x: F) -> F { + let evaluate_serial = |coeffs: &[F]| { + coeffs + .iter() + .rev() + .fold(F::zero(), |acc, coeff| acc * x + coeff) + }; + + #[cfg(feature = "parallel")] + { + use crate::util::{arithmetic::powers, current_num_threads, parallelize_iter}; + use num_integer::Integer; + + let num_threads = current_num_threads(); + if self.len() * 2 < num_threads { + return evaluate_serial(&self.0); + } + + let chunk_size = Integer::div_ceil(&self.len(), &num_threads); + let mut results = vec![F::zero(); num_threads]; + parallelize_iter( + results + .iter_mut() + .zip(self.0.chunks(chunk_size)) + .zip(powers(x.pow_vartime(&[chunk_size as u64, 0, 0, 0]))), + |((result, coeffs), scalar)| *result = evaluate_serial(coeffs) * scalar, + ); + results.iter().fold(F::zero(), |acc, result| acc + result) + } + #[cfg(not(feature = "parallel"))] + evaluate_serial(&self.0) + } +} + +impl<'a, F: Field> Add<&'a Polynomial> for Polynomial { + type Output = Polynomial; + + fn add(mut self, rhs: &'a Polynomial) -> Polynomial { + parallelize(&mut self.0, |(lhs, start)| { + for (lhs, rhs) in lhs.iter_mut().zip(rhs.0[start..].iter()) { + *lhs += *rhs; + } + }); + self + } +} + +impl<'a, F: Field> Sub<&'a Polynomial> for Polynomial { + type Output = Polynomial; + + fn sub(mut self, rhs: &'a Polynomial) -> Polynomial { + parallelize(&mut self.0, |(lhs, start)| { + for (lhs, rhs) in lhs.iter_mut().zip(rhs.0[start..].iter()) { + *lhs -= *rhs; + } + }); + self + } +} + +impl Sub for Polynomial { + type Output = Polynomial; + + fn sub(mut self, rhs: F) -> Polynomial { + self.0[0] -= rhs; + self + } +} + +impl Add for Polynomial { + type Output = Polynomial; + + fn add(mut self, rhs: F) -> Polynomial { + self.0[0] += rhs; + self + } +} + +impl Mul for Polynomial { + type Output = Polynomial; + + fn mul(mut self, rhs: F) -> Polynomial { + if rhs == F::zero() { + return Polynomial::new(vec![F::zero(); self.len()]); + } + if rhs == F::one() { + return self; + } + parallelize(&mut self.0, |(lhs, _)| { + for lhs in lhs.iter_mut() { + *lhs *= rhs; + } + }); + self + } +} + +impl Sum for Polynomial { + fn sum>(iter: I) -> Self { + iter.reduce(|acc, item| acc + &item).unwrap() + } +} + +macro_rules! impl_index { + ($($range:ty => $output:ty,)*) => { + $( + impl Index<$range> for Polynomial { + type Output = $output; + + fn index(&self, index: $range) -> &$output { + self.0.index(index) + } + } + impl IndexMut<$range> for Polynomial { + fn index_mut(&mut self, index: $range) -> &mut $output { + self.0.index_mut(index) + } + } + )* + }; +} + +impl_index!( + usize => F, + Range => [F], + RangeFrom => [F], + RangeFull => [F], + RangeInclusive => [F], + RangeTo => [F], + RangeToInclusive => [F], +); From 75f7532a9c423d209dec179ba8f9ffdbf5513ec2 Mon Sep 17 00:00:00 2001 From: han0110 Date: Fri, 28 Oct 2022 18:48:10 +0800 Subject: [PATCH 19/28] feat: add example `recursion` --- Cargo.toml | 4 + examples/recursion.rs | 860 +++++++++++++++++++++++++++++++++++++ src/loader/halo2/loader.rs | 7 + 3 files changed, 871 insertions(+) create mode 100644 examples/recursion.rs diff --git a/Cargo.toml b/Cargo.toml index 44dcd60b..9251d05e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,3 +58,7 @@ required-features = ["loader_evm", "system_halo2"] [[example]] name = "evm-verifier-with-accumulator" required-features = ["loader_halo2", "loader_evm", "system_halo2"] + +[[example]] +name = "recursion" +required-features = ["loader_halo2", "system_halo2"] diff --git a/examples/recursion.rs b/examples/recursion.rs new file mode 100644 index 00000000..a814b4b1 --- /dev/null +++ b/examples/recursion.rs @@ -0,0 +1,860 @@ +#![allow(clippy::type_complexity)] + +use common::*; +use halo2_curves::{ + bn256::{Bn256, Fq, Fr, G1Affine}, + group::ff::Field, + FieldExt, +}; +use halo2_proofs::{ + circuit::{AssignedCell, Layouter, SimpleFloorPlanner, Value}, + dev::MockProver, + plonk::{ + create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ConstraintSystem, Error, + ProvingKey, Selector, VerifyingKey, + }, + poly::{ + commitment::ParamsProver, + kzg::{ + commitment::ParamsKZG, + multiopen::{ProverGWC, VerifierGWC}, + strategy::AccumulatorStrategy, + }, + Rotation, VerificationStrategy, + }, +}; +use itertools::Itertools; +use plonk_verifier::{ + loader::{self, native::NativeLoader, Loader, ScalarLoader}, + pcs::{ + kzg::{Gwc19, Kzg, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding}, + AccumulationScheme, AccumulationSchemeProver, + }, + system::halo2::{self, compile, Config}, + util::{ + arithmetic::{fe_to_fe, fe_to_limbs}, + hash, + }, + verifier::{self, PlonkProof, PlonkVerifier}, + Protocol, +}; +use rand_chacha::{ + rand_core::{OsRng, SeedableRng}, + ChaCha20Rng, +}; +use std::{fs, iter, marker::PhantomData, rc::Rc}; + +const LIMBS: usize = 4; +const BITS: usize = 68; +const T: usize = 5; +const RATE: usize = 4; +const R_F: usize = 8; +const R_P: usize = 60; + +type Pcs = Kzg; +type Svk = KzgSuccinctVerifyingKey; +type As = KzgAs; +type Plonk = verifier::Plonk>; +type Poseidon = hash::Poseidon; +type PoseidonTranscript = + halo2::transcript::halo2::PoseidonTranscript; + +mod common { + use super::*; + use halo2_proofs::poly::commitment::Params; + use plonk_verifier::{cost::CostEstimation, util::transcript::TranscriptWrite}; + + pub fn poseidon>( + loader: &L, + inputs: &[L::LoadedScalar], + ) -> L::LoadedScalar { + let mut hasher = Poseidon::new(loader, R_F, R_P); + hasher.update(inputs); + hasher.squeeze() + } + + pub struct Snark { + pub protocol: Protocol, + pub instances: Vec>, + pub proof: Vec, + } + + impl Snark { + pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { + Self { + protocol, + instances, + proof, + } + } + } + + impl From for SnarkWitness { + fn from(snark: Snark) -> Self { + Self { + protocol: snark.protocol, + instances: snark + .instances + .into_iter() + .map(|instances| instances.into_iter().map(Value::known).collect_vec()) + .collect(), + proof: Value::known(snark.proof), + } + } + } + + #[derive(Clone)] + pub struct SnarkWitness { + pub protocol: Protocol, + pub instances: Vec>>, + pub proof: Value>, + } + + impl SnarkWitness { + pub fn without_witnesses(&self) -> Self { + SnarkWitness { + protocol: self.protocol.clone(), + instances: self + .instances + .iter() + .map(|instances| vec![Value::unknown(); instances.len()]) + .collect(), + proof: Value::unknown(), + } + } + + pub fn proof(&self) -> Value<&[u8]> { + self.proof.as_ref().map(Vec::as_slice) + } + } + + pub trait CircuitExt: Circuit { + fn num_instance() -> Vec; + + fn instances(&self) -> Vec>; + + fn accumulator_indices() -> Option> { + None + } + } + + pub fn gen_srs(k: u32) -> ParamsKZG { + let path = format!("./examples/k-{}.srs", k); + match fs::File::open(path.as_str()) { + Ok(mut file) => ParamsKZG::read(&mut file).unwrap(), + Err(_) => { + let params = + ParamsKZG::::setup(k, ChaCha20Rng::from_seed(Default::default())); + params.write(&mut fs::File::create(path).unwrap()).unwrap(); + params + } + } + } + + pub fn gen_pk>(params: &ParamsKZG, circuit: &C) -> ProvingKey { + let vk = keygen_vk(params, circuit).unwrap(); + keygen_pk(params, vk, circuit).unwrap() + } + + pub fn gen_proof>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: C, + instances: Vec>, + ) -> Vec { + if params.k() > 3 { + MockProver::run(params.k(), &circuit, instances.clone()) + .unwrap() + .assert_satisfied(); + } + + let instances = instances.iter().map(Vec::as_slice).collect_vec(); + let proof = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + create_proof::<_, ProverGWC<_>, _, _, _, _>( + params, + pk, + &[circuit], + &[instances.as_slice()], + OsRng, + &mut transcript, + ) + .unwrap(); + transcript.finalize() + }; + + let accept = { + let mut transcript = PoseidonTranscript::::new(proof.as_slice()); + VerificationStrategy::<_, VerifierGWC<_>>::finalize( + verify_proof::<_, VerifierGWC<_>, _, _, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript, + ) + .unwrap(), + ) + }; + assert!(accept); + + proof + } + + pub fn gen_snark>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: ConcreteCircuit, + ) -> Snark { + let protocol = compile( + params, + pk.get_vk(), + Config::kzg() + .with_num_instance(ConcreteCircuit::num_instance()) + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), + ); + + let instances = circuit.instances(); + let proof = gen_proof(params, pk, circuit, instances.clone()); + + Snark::new(protocol, instances, proof) + } + + pub fn gen_dummy_snark>( + params: &ParamsKZG, + vk: Option<&VerifyingKey>, + ) -> Snark { + struct CsProxy(PhantomData<(F, C)>); + + impl> Circuit for CsProxy { + type Config = C::Config; + type FloorPlanner = C::FloorPlanner; + + fn without_witnesses(&self) -> Self { + CsProxy(PhantomData) + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + C::configure(meta) + } + + fn synthesize(&self, _: Self::Config, _: impl Layouter) -> Result<(), Error> { + Ok(()) + } + } + + let dummy_vk = vk + .is_none() + .then(|| keygen_vk(params, &CsProxy::(PhantomData)).unwrap()); + let protocol = compile( + params, + vk.or(dummy_vk.as_ref()).unwrap(), + Config::kzg() + .with_num_instance(ConcreteCircuit::num_instance()) + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), + ); + let instances = ConcreteCircuit::num_instance() + .into_iter() + .map(|n| iter::repeat_with(|| Fr::random(OsRng)).take(n).collect()) + .collect(); + let proof = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + for _ in 0..protocol + .num_witness + .iter() + .chain(Some(&protocol.quotient.num_chunk())) + .sum::() + { + transcript.write_ec_point(G1Affine::random(OsRng)).unwrap(); + } + for _ in 0..protocol.evaluations.len() { + transcript.write_scalar(Fr::random(OsRng)).unwrap(); + } + let queries = PlonkProof::::empty_queries(&protocol); + for _ in 0..Pcs::estimate_cost(&queries).num_commitment { + transcript.write_ec_point(G1Affine::random(OsRng)).unwrap(); + } + transcript.finalize() + }; + + Snark::new(protocol, instances, proof) + } +} + +mod application { + use super::*; + + #[derive(Clone, Default)] + pub struct Square(Fr); + + impl Circuit for Square { + type Config = Selector; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let q = meta.selector(); + let i = meta.instance_column(); + meta.create_gate("square", |meta| { + let q = meta.query_selector(q); + let [i, i_w] = [0, 1].map(|rotation| meta.query_instance(i, Rotation(rotation))); + Some(q * (i.clone() * i - i_w)) + }); + q + } + + fn synthesize( + &self, + q: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region(|| "", |mut region| q.enable(&mut region, 0)) + } + } + + impl CircuitExt for Square { + fn num_instance() -> Vec { + vec![2] + } + + fn instances(&self) -> Vec> { + vec![vec![self.0, self.0.square()]] + } + } + + impl recursion::StateTransition for Square { + type Input = (); + + fn new(state: Fr) -> Self { + Self(state) + } + + fn state_transition(&self, _: Self::Input) -> Fr { + self.0.square() + } + } +} + +mod recursion { + use super::*; + use halo2_wrong_ecc::{ + integer::rns::Rns, + maingate::{ + MainGate, MainGateConfig, MainGateInstructions, RangeChip, RangeConfig, + RangeInstructions, RegionCtx, + }, + EccConfig, + }; + + type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip; + type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; + + pub trait StateTransition { + type Input; + + fn new(state: Fr) -> Self; + + fn state_transition(&self, input: Self::Input) -> Fr; + } + + fn succinct_verify<'a>( + svk: &Svk, + loader: &Rc>, + snark: &SnarkWitness, + preprocessed_digest: Option>, + ) -> ( + Vec>>, + Vec>>>, + ) { + let protocol = if let Some(preprocessed_digest) = preprocessed_digest { + let preprocessed_digest = loader.scalar_from_assigned(preprocessed_digest); + let protocol = snark.protocol.loaded_preprocessed_as_witness(loader); + let inputs = protocol + .preprocessed + .iter() + .flat_map(|preprocessed| { + let assigned = preprocessed.assigned(); + [assigned.x(), assigned.y()] + .map(|coordinate| loader.scalar_from_assigned(coordinate.native().clone())) + }) + .chain(protocol.transcript_initial_state.clone()) + .collect_vec(); + loader + .assert_eq("", &poseidon(loader, &inputs), &preprocessed_digest) + .unwrap(); + protocol + } else { + snark.protocol.loaded(loader) + }; + + let instances = snark + .instances + .iter() + .map(|instances| { + instances + .iter() + .map(|instance| loader.assign_scalar(*instance)) + .collect_vec() + }) + .collect_vec(); + let mut transcript = PoseidonTranscript::, _>::new(loader, snark.proof()); + let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript).unwrap(); + let accumulators = Plonk::succinct_verify(svk, &protocol, &instances, &proof).unwrap(); + + ( + instances + .into_iter() + .map(|instance| { + instance + .into_iter() + .map(|instance| instance.into_assigned()) + .collect() + }) + .collect(), + accumulators, + ) + } + + fn select_accumulator<'a>( + loader: &Rc>, + condition: &AssignedCell, + lhs: &KzgAccumulator>>, + rhs: &KzgAccumulator>>, + ) -> Result>>, Error> { + let [lhs, rhs]: [_; 2] = [lhs.lhs.assigned(), lhs.rhs.assigned()] + .iter() + .zip([rhs.lhs.assigned(), rhs.rhs.assigned()].iter()) + .map(|(lhs, rhs)| { + let selected = + loader + .ecc_chip() + .select(&mut loader.ctx_mut(), condition, lhs, rhs)?; + Ok(loader.ec_point_from_assigned(selected)) + }) + .collect::, Error>>()? + .try_into() + .unwrap(); + Ok(KzgAccumulator::new(lhs, rhs)) + } + + fn accumulate<'a>( + loader: &Rc>, + accumulators: Vec>>>, + as_proof: Value<&'_ [u8]>, + ) -> KzgAccumulator>> { + let mut transcript = PoseidonTranscript::, _>::new(loader, as_proof); + let proof = As::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); + As::verify(&Default::default(), &accumulators, &proof).unwrap() + } + + #[derive(Clone)] + pub struct RecursionConfig { + main_gate_config: MainGateConfig, + range_config: RangeConfig, + } + + impl RecursionConfig { + pub fn main_gate(&self) -> MainGate { + MainGate::new(self.main_gate_config.clone()) + } + + pub fn range_chip(&self) -> RangeChip { + RangeChip::new(self.range_config.clone()) + } + + pub fn ecc_chip(&self) -> BaseFieldEccChip { + BaseFieldEccChip::new(EccConfig::new( + self.range_config.clone(), + self.main_gate_config.clone(), + )) + } + } + + #[derive(Clone)] + pub struct RecursionCircuit { + svk: Svk, + default_accumulator: KzgAccumulator, + app: SnarkWitness, + previous: SnarkWitness, + round: usize, + instances: Vec, + as_proof: Value>, + } + + impl RecursionCircuit { + const PREPROCESSED_DIGEST_ROW: usize = 4 * LIMBS; + const INITIAL_STATE_ROW: usize = 4 * LIMBS + 1; + const STATE_ROW: usize = 4 * LIMBS + 2; + const ROUND_ROW: usize = 4 * LIMBS + 3; + + pub fn new( + params: &ParamsKZG, + app: Snark, + previous: Snark, + initial_state: Fr, + state: Fr, + round: usize, + ) -> Self { + let svk = params.get_g()[0].into(); + let default_accumulator = KzgAccumulator::new(params.get_g()[1], params.get_g()[0]); + + let succinct_verify = |snark: &Snark| { + let mut transcript = + PoseidonTranscript::::new(snark.proof.as_slice()); + let proof = + Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript) + .unwrap(); + Plonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof).unwrap() + }; + + let accumulators = iter::empty() + .chain(succinct_verify(&app)) + .chain( + (round > 0) + .then(|| succinct_verify(&previous)) + .unwrap_or_else(|| { + let num_accumulator = 1 + previous.protocol.accumulator_indices.len(); + vec![default_accumulator.clone(); num_accumulator] + }), + ) + .collect_vec(); + + let (accumulator, as_proof) = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + let accumulator = + As::create_proof(&Default::default(), &accumulators, &mut transcript, OsRng) + .unwrap(); + (accumulator, transcript.finalize()) + }; + + let preprocessed_digest = { + let inputs = previous + .protocol + .preprocessed + .iter() + .flat_map(|preprocessed| [preprocessed.x, preprocessed.y]) + .map(fe_to_fe) + .chain(previous.protocol.transcript_initial_state) + .collect_vec(); + poseidon(&NativeLoader, &inputs) + }; + let instances = [ + accumulator.lhs.x, + accumulator.lhs.y, + accumulator.rhs.x, + accumulator.rhs.y, + ] + .into_iter() + .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) + .chain([ + preprocessed_digest, + initial_state, + state, + Fr::from(round as u64), + ]) + .collect(); + + Self { + svk, + default_accumulator, + app: app.into(), + previous: previous.into(), + round, + instances, + as_proof: Value::known(as_proof), + } + } + + fn initial_snark(params: &ParamsKZG, vk: Option<&VerifyingKey>) -> Snark { + let mut snark = gen_dummy_snark::(params, vk); + let g = params.get_g(); + snark.instances = vec![[g[1].x, g[1].y, g[0].x, g[0].y] + .into_iter() + .flat_map(fe_to_limbs::<_, _, LIMBS, BITS>) + .chain([Fr::zero(); 4]) + .collect_vec()]; + snark + } + + fn as_proof(&self) -> Value<&[u8]> { + self.as_proof.as_ref().map(Vec::as_slice) + } + + fn load_default_accumulator<'a>( + &self, + loader: &Rc>, + ) -> Result>>, Error> { + let [lhs, rhs] = + [self.default_accumulator.lhs, self.default_accumulator.rhs].map(|default| { + let assigned = loader + .ecc_chip() + .assign_constant(&mut loader.ctx_mut(), default) + .unwrap(); + loader.ec_point_from_assigned(assigned) + }); + Ok(KzgAccumulator::new(lhs, rhs)) + } + } + + impl Circuit for RecursionCircuit { + type Config = RecursionConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self { + svk: self.svk, + default_accumulator: self.default_accumulator.clone(), + app: self.app.without_witnesses(), + previous: self.previous.without_witnesses(), + round: self.round, + instances: self.instances.clone(), + as_proof: Value::unknown(), + } + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let main_gate_config = MainGate::::configure(meta); + let range_config = RangeChip::::configure( + meta, + &main_gate_config, + vec![BITS / LIMBS], + Rns::::construct().overflow_lengths(), + ); + RecursionConfig { + main_gate_config, + range_config, + } + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let main_gate = config.main_gate(); + let range_chip = config.range_chip(); + + range_chip.load_table(&mut layouter)?; + + let [preprocessed_digest, initial_state, state, round, first_round, not_first_round] = + layouter.assign_region( + || "", + |region| { + let mut ctx = RegionCtx::new(region, 0); + let [preprocessed_digest, initial_state, state, round] = [ + self.instances[Self::PREPROCESSED_DIGEST_ROW], + self.instances[Self::INITIAL_STATE_ROW], + self.instances[Self::STATE_ROW], + self.instances[Self::ROUND_ROW], + ] + .map(|instance| { + main_gate + .assign_value(&mut ctx, Value::known(instance)) + .unwrap() + }); + let first_round = main_gate.is_zero(&mut ctx, &round)?; + let not_first_round = main_gate.not(&mut ctx, &first_round)?; + Ok([ + preprocessed_digest, + initial_state, + state, + round, + first_round, + not_first_round, + ]) + }, + )?; + + let (lhs, rhs, app_instances, previous_instances) = layouter.assign_region( + || "", + |region| { + let loader = Halo2Loader::new(config.ecc_chip(), RegionCtx::new(region, 0)); + let (mut app_instances, app_accumulators) = + succinct_verify(&self.svk, &loader, &self.app, None); + let (mut previous_instances, previous_accumulators) = succinct_verify( + &self.svk, + &loader, + &self.previous, + Some(preprocessed_digest.clone()), + ); + + let default_accmulator = self.load_default_accumulator(&loader)?; + let previous_accumulators = previous_accumulators + .iter() + .map(|previous_accumulator| { + select_accumulator( + &loader, + &first_round, + &default_accmulator, + previous_accumulator, + ) + }) + .collect::, Error>>()?; + + let KzgAccumulator { lhs, rhs } = accumulate( + &loader, + [app_accumulators, previous_accumulators].concat(), + self.as_proof(), + ); + + Ok(( + lhs.into_assigned(), + rhs.into_assigned(), + app_instances.pop().unwrap(), + previous_instances.pop().unwrap(), + )) + }, + )?; + + layouter.assign_region( + || "", + |region| { + let mut ctx = RegionCtx::new(region, 0); + for (lhs, rhs) in [ + // Propagate preprocessed_digest + ( + &main_gate.mul(&mut ctx, &preprocessed_digest, ¬_first_round)?, + &previous_instances[Self::PREPROCESSED_DIGEST_ROW], + ), + // Propagate initial_state + ( + &main_gate.mul(&mut ctx, &initial_state, ¬_first_round)?, + &previous_instances[Self::INITIAL_STATE_ROW], + ), + // Verify initial_state is same as the first application snark + ( + &main_gate.mul(&mut ctx, &initial_state, &first_round)?, + &main_gate.mul(&mut ctx, &app_instances[0], &first_round)?, + ), + // Verify current state is same as the current application snark + (&state, &app_instances[1]), + // Verify previous state is same as the current application snark + ( + &main_gate.mul(&mut ctx, &app_instances[0], ¬_first_round)?, + &previous_instances[Self::STATE_ROW], + ), + // Verify round is increased by 1 when not at first round + ( + &round, + &main_gate.add( + &mut ctx, + ¬_first_round, + &previous_instances[Self::ROUND_ROW], + )?, + ), + ] { + main_gate.assert_equal(&mut ctx, lhs, rhs)?; + } + Ok(()) + }, + )?; + + for (row, limb) in [lhs.x(), lhs.y(), rhs.x(), rhs.y()] + .into_iter() + .flat_map(|coordinate| coordinate.limbs()) + .map_into() + .chain([preprocessed_digest, initial_state, state, round]) + .enumerate() + { + main_gate.expose_public(layouter.namespace(|| ""), limb, row)?; + } + + Ok(()) + } + } + + impl CircuitExt for RecursionCircuit { + fn num_instance() -> Vec { + // [..lhs, ..rhs, preprocessed_digest, initial_state, state, round] + vec![4 * LIMBS + 4] + } + + fn instances(&self) -> Vec> { + vec![self.instances.clone()] + } + + fn accumulator_indices() -> Option> { + Some((0..4 * LIMBS).map(|idx| (0, idx)).collect()) + } + } + + pub fn gen_recursion_pk>( + recursion_params: &ParamsKZG, + app_params: &ParamsKZG, + app_vk: &VerifyingKey, + ) -> ProvingKey { + let recursion = RecursionCircuit::new( + recursion_params, + gen_dummy_snark::(app_params, Some(app_vk)), + RecursionCircuit::initial_snark(recursion_params, None), + Fr::zero(), + Fr::zero(), + 0, + ); + gen_pk(recursion_params, &recursion) + } + + pub fn gen_recursion_snark + StateTransition>( + app_params: &ParamsKZG, + recursion_params: &ParamsKZG, + app_pk: &ProvingKey, + recursion_pk: &ProvingKey, + initial_state: Fr, + inputs: Vec, + ) -> (Fr, Snark) { + let mut state = initial_state; + let mut app = ConcreteCircuit::new(state); + let mut previous = + RecursionCircuit::initial_snark(recursion_params, Some(recursion_pk.get_vk())); + for (round, input) in inputs.into_iter().enumerate() { + state = app.state_transition(input); + let recursion = RecursionCircuit::new( + recursion_params, + gen_snark(app_params, app_pk, app), + previous, + initial_state, + state, + round, + ); + previous = gen_snark(recursion_params, recursion_pk, recursion); + app = ConcreteCircuit::new(state); + } + (state, previous) + } +} + +fn main() { + let app_params = gen_srs(3); + let recursion_params = gen_srs(22); + + let app_pk = gen_pk(&app_params, &application::Square::default()); + let recursion_pk = recursion::gen_recursion_pk::( + &recursion_params, + &app_params, + app_pk.get_vk(), + ); + + let num_round = 3; + let (final_state, snark) = recursion::gen_recursion_snark::( + &app_params, + &recursion_params, + &app_pk, + &recursion_pk, + Fr::from(2), + vec![(); num_round], + ); + assert_eq!(final_state, Fr::from(2).pow(&[1 << num_round, 0, 0, 0])); + + let accept = { + let svk = recursion_params.get_g()[0].into(); + let dk = (recursion_params.g2(), recursion_params.s_g2()).into(); + let mut transcript = PoseidonTranscript::::new(snark.proof.as_slice()); + let proof = + Plonk::read_proof(&svk, &snark.protocol, &snark.instances, &mut transcript).unwrap(); + Plonk::verify(&svk, &dk, &snark.protocol, &snark.instances, &proof).unwrap() + }; + assert!(accept) +} diff --git a/src/loader/halo2/loader.rs b/src/loader/halo2/loader.rs index 67c2b12d..b49b3694 100644 --- a/src/loader/halo2/loader.rs +++ b/src/loader/halo2/loader.rs @@ -319,6 +319,13 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Scalar<'a, C, EccChip> &self.loader } + pub fn into_assigned(self) -> EccChip::AssignedScalar { + match self.value.into_inner() { + Value::Constant(constant) => self.loader.assign_const_scalar(constant), + Value::Assigned(assigned) => assigned, + } + } + pub fn assigned(&self) -> Ref { if let Some(constant) = self.maybe_const() { *self.value.borrow_mut() = Value::Assigned(self.loader.assign_const_scalar(constant)) From a50d7a1a2e0034c5c6f69ca599fcd661259a15ee Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Mon, 21 Nov 2022 18:11:12 -0500 Subject: [PATCH 20/28] wip: update to halo2-lib v0.2.0 --- Cargo.toml | 52 ++++++++-- examples/evm-verifier-with-accumulator.rs | 9 +- examples/recursion.rs | 112 ++++++++++++++-------- rust-toolchain | 2 +- src/loader/evm/loader.rs | 8 +- src/loader/halo2/shim.rs | 106 ++++++++++---------- src/pcs/kzg/accumulator.rs | 9 +- src/system/halo2.rs | 12 +-- src/system/halo2/transcript/halo2.rs | 8 +- src/util/arithmetic.rs | 6 +- src/verifier/plonk.rs | 7 +- 11 files changed, 200 insertions(+), 131 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ae646f4e..482748d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,13 @@ name = "plonk_verifier" version = "0.1.0" edition = "2021" +[patch."https://github.com/privacy-scaling-explorations/halo2curves"] +halo2curves = { git = "https://github.com/axiom-crypto/halo2.git", branch = "axiom/faster-witness-generation" } + +# [patch."ssh://github.com/axiom-crypto/halo2-lib-working.git"] +# halo2_base = { path = "../halo2-lib-working/halo2_base" } +# halo2_ecc = { path = "../halo2-lib-working/halo2_ecc" } + [dependencies] itertools = "0.10.3" lazy_static = "1.4.0" @@ -12,18 +19,25 @@ num-traits = "0.2.15" rand = "0.8" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.3.0", package = "halo2curves" } + +halo2_curves = { git = "https://github.com/axiom-crypto/halo2", branch = "axiom/faster-witness-generation", package = "halo2curves" } + +# parallel +rayon = { version = "1.5.3", optional = true } # system_halo2 halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2", tag = "v2022_10_22", optional = true } # loader_evm -ethereum_types = { package = "ethereum-types", version = "0.13.1", default-features = false, features = ["std"], optional = true } -sha3 = { version = "0.10.1", optional = true } +ethereum_types = { package = "ethereum-types", version = "0.13", default-features = false, features = ["std"], optional = true } +sha3 = { version = "0.10", optional = true } +revm = { version = "2.1.0", optional = true } +bytes = { version = "1.2", optional = true } +rlp = { version = "0.5", default-features = false, features = ["std"], optional = true } # loader_halo2 -halo2_base = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", package = "halo2_base", default-features = false, optional = true } -halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", package = "halo2_ecc", default-features = false, optional = true } +halo2_base = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", tag = "v0.2.0", package = "halo2_base", default-features = false, optional = true } +halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", tag = "v0.2.0", package = "halo2_ecc", default-features = false, optional = true } halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", tag = "v2022_10_22", package = "ecc", optional = true } poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon", tag = "v2022_10_22", optional = true } @@ -32,18 +46,19 @@ rand_chacha = "0.3.1" paste = "1.0.7" # system_halo2 -halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", package = "halo2_ecc", default-features = false } +halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", tag = "v0.2.0", package = "halo2_ecc", default-features = false } halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", tag = "v2022_10_22", package = "ecc" } # loader_evm -foundry_evm = { git = "https://github.com/foundry-rs/foundry", package = "foundry-evm", rev = "6b1ee60e" } crossterm = { version = "0.22.1" } tui = { version = "0.16.0", default-features = false, features = ["crossterm"] } [features] default = ["loader_evm", "loader_halo2", "system_halo2"] -loader_evm = ["dep:ethereum_types", "dep:sha3"] +parallel = ["dep:rayon"] + +loader_evm = ["dep:ethereum_types", "dep:sha3", "dep:revm", "dep:bytes", "dep:rlp"] loader_halo2 = ["dep:halo2_proofs", "dep:halo2_base", "halo2_ecc", "dep:halo2_wrong_ecc", "dep:poseidon"] system_halo2 = ["dep:halo2_proofs"] @@ -61,3 +76,24 @@ required-features = ["loader_halo2", "loader_evm", "system_halo2"] [[example]] name = "recursion" required-features = ["loader_halo2", "system_halo2"] + +[profile.dev] +opt-level = 3 + +# Local "release" mode, more optimized than dev but much faster to compile than release +[profile.local] +inherits = "dev" +opt-level = 3 +# Set this to 1 or 2 to get more useful backtraces +debug = 0 +panic = 'unwind' +# better recompile times +incremental = true +codegen-units = 16 + +[profile.release] +debug = 0 +opt-level = 3 +lto = "fat" +codegen-unit = 1 +panic = "abort" \ No newline at end of file diff --git a/examples/evm-verifier-with-accumulator.rs b/examples/evm-verifier-with-accumulator.rs index ee8c1526..2dc38f8d 100644 --- a/examples/evm-verifier-with-accumulator.rs +++ b/examples/evm-verifier-with-accumulator.rs @@ -340,7 +340,8 @@ mod aggregation { params.limb_bits, params.num_limbs, halo2_base::utils::modulus::(), - "verifier".to_string(), + 0, + params.degree, ); let instance = meta.instance_column(); @@ -472,10 +473,8 @@ mod aggregation { let ctx = Context::new( region, ContextParams { - num_advice: vec![( - config.base_field_config.range.context_id.clone(), - config.base_field_config.range.gate.num_advice, - )], + max_rows: config.range().gate.max_rows, + num_advice: vec![config.base_field_config.range.gate.num_advice], fixed_columns: config.base_field_config.range.gate.constants.clone(), }, ); diff --git a/examples/recursion.rs b/examples/recursion.rs index 5c34d663..e7ac8f04 100644 --- a/examples/recursion.rs +++ b/examples/recursion.rs @@ -45,12 +45,12 @@ use rand::{rngs::OsRng, SeedableRng}; use rand_chacha::ChaCha20Rng; use std::{fs, iter, marker::PhantomData, rc::Rc}; -const LIMBS: usize = 4; -const BITS: usize = 68; -const T: usize = 5; -const RATE: usize = 4; +const LIMBS: usize = 3; +const BITS: usize = 88; +const T: usize = 3; +const RATE: usize = 2; const R_F: usize = 8; -const R_P: usize = 60; +const R_P: usize = 57; type Pcs = Kzg; type Svk = KzgSuccinctVerifyingKey; @@ -341,18 +341,13 @@ mod application { } mod recursion { + use halo2_base::AssignedValue; + use halo2_ecc::ecc::EccChip; + use super::*; - use halo2_wrong_ecc::{ - integer::rns::Rns, - maingate::{ - MainGate, MainGateConfig, MainGateInstructions, RangeChip, RangeConfig, - RangeInstructions, RegionCtx, - }, - EccConfig, - }; - type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip; - type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; + type BaseFieldEccChip<'b> = halo2_ecc::ecc::BaseFieldEccChip<'b, G1Affine>; + type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip<'a>>; pub trait StateTransition { type Input; @@ -366,9 +361,9 @@ mod recursion { svk: &Svk, loader: &Rc>, snark: &SnarkWitness, - expected_preprocessed_digest: Option>, + expected_preprocessed_digest: Option>, ) -> ( - Vec>>, + Vec>>, Vec>>>, ) { let protocol = if let Some(expected_preprocessed_digest) = expected_preprocessed_digest { @@ -428,7 +423,7 @@ mod recursion { fn select_accumulator<'a>( loader: &Rc>, - condition: &AssignedCell, + condition: &AssignedValue, lhs: &KzgAccumulator>>, rhs: &KzgAccumulator>>, ) -> Result>>, Error> { @@ -438,7 +433,7 @@ mod recursion { .map(|(lhs, rhs)| { loader .ecc_chip() - .select(&mut loader.ctx_mut(), condition, lhs, rhs) + .select(&mut loader.ctx_mut(), lhs, rhs, condition) }) .collect::, _>>()? .try_into() @@ -459,26 +454,64 @@ mod recursion { As::verify(&Default::default(), &accumulators, &proof).unwrap() } + #[derive(serde::Serialize, serde::Deserialize)] + pub struct AggregationConfigParams { + pub strategy: halo2_ecc::fields::fp::FpStrategy, + pub degree: u32, + pub num_advice: usize, + pub num_lookup_advice: usize, + pub num_fixed: usize, + pub lookup_bits: usize, + pub limb_bits: usize, + pub num_limbs: usize, + } + #[derive(Clone)] pub struct RecursionConfig { - main_gate_config: MainGateConfig, - range_config: RangeConfig, + pub base_field_config: halo2_ecc::fields::fp::FpConfig, + pub instance: Column, } impl RecursionConfig { - pub fn main_gate(&self) -> MainGate { - MainGate::new(self.main_gate_config.clone()) + pub fn configure(meta: &mut ConstraintSystem, params: AggregationConfigParams) -> Self { + assert!( + params.limb_bits == BITS && params.num_limbs == LIMBS, + "For now we fix limb_bits = {}, otherwise change code", + BITS + ); + let base_field_config = halo2_ecc::fields::fp::FpConfig::configure( + meta, + params.strategy, + &[params.num_advice], + &[params.num_lookup_advice], + params.num_fixed, + params.lookup_bits, + params.limb_bits, + params.num_limbs, + halo2_base::utils::modulus::(), + 0, + params.degree, + ); + + let instance = meta.instance_column(); + meta.enable_equality(instance); + + Self { + base_field_config, + instance, + } } - pub fn range_chip(&self) -> RangeChip { - RangeChip::new(self.range_config.clone()) + pub fn gate(&self) -> &halo2_base::gates::flex_gate::FlexGateConfig { + &self.base_field_config.range.gate } - pub fn ecc_chip(&self) -> BaseFieldEccChip { - BaseFieldEccChip::new(EccConfig::new( - self.range_config.clone(), - self.main_gate_config.clone(), - )) + pub fn range(&self) -> &halo2_base::gates::range::RangeConfig { + &self.base_field_config.range + } + + pub fn ecc_chip(&self) -> halo2_ecc::ecc::BaseFieldEccChip<'_, G1Affine> { + EccChip::construct(&self.base_field_config) } } @@ -629,17 +662,14 @@ mod recursion { } fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { - let main_gate_config = MainGate::::configure(meta); - let range_config = RangeChip::::configure( - meta, - &main_gate_config, - vec![BITS / LIMBS], - Rns::::construct().overflow_lengths(), - ); - RecursionConfig { - main_gate_config, - range_config, - } + let path = std::env::var("VERIFY_CONFIG") + .unwrap_or_else(|| "configs/verify_circuit.config".to_owned()); + let params: AggregationConfigParams = serde_json::from_reader( + File::open(path.as_str()).expect(format!("{path} file should exist").as_str()), + ) + .unwrap(); + + RecursionConfig::configure(meta, params) } fn synthesize( diff --git a/rust-toolchain b/rust-toolchain index 7cc6ef41..51ab4759 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.63.0 \ No newline at end of file +nightly-2022-10-28 \ No newline at end of file diff --git a/src/loader/evm/loader.rs b/src/loader/evm/loader.rs index 7d1dbb94..d9d7ca5a 100644 --- a/src/loader/evm/loader.rs +++ b/src/loader/evm/loader.rs @@ -38,10 +38,10 @@ impl PartialEq for Value { impl Value { fn identifier(&self) -> String { match &self { - Value::Constant(_) | Value::Memory(_) => format!("{:?}", self), - Value::Negated(value) => format!("-({:?})", value), - Value::Sum(lhs, rhs) => format!("({:?} + {:?})", lhs, rhs), - Value::Product(lhs, rhs) => format!("({:?} * {:?})", lhs, rhs), + Value::Constant(_) | Value::Memory(_) => format!("{self:?}"), + Value::Negated(value) => format!("-({value:?})"), + Value::Sum(lhs, rhs) => format!("({lhs:?} + {rhs:?})"), + Value::Product(lhs, rhs) => format!("({lhs:?} * {rhs:?})"), } } } diff --git a/src/loader/halo2/shim.rs b/src/loader/halo2/shim.rs index 3bbac5a1..8aaaaf45 100644 --- a/src/loader/halo2/shim.rs +++ b/src/loader/halo2/shim.rs @@ -134,18 +134,18 @@ pub trait EccInstructions<'a, C: CurveAffine>: Clone + Debug { mod halo2_lib { use crate::{ loader::halo2::{Context, EccInstructions, IntegerInstructions}, - util::arithmetic::{CurveAffine, Field, FieldExt, PrimeField}, + util::arithmetic::{CurveAffine, Field, PrimeField}, }; use halo2_base::{ self, gates::{flex_gate::FlexGateConfig, GateInstructions, RangeInstructions}, AssignedValue, - QuantumCell::{Constant, Existing}, + QuantumCell::{Constant, Existing, Witness}, }; - use halo2_curves::group::prime::PrimeCurveAffine; + use halo2_curves::BigPrimeField; use halo2_ecc::{ bigint::CRTInteger, - ecc::{fixed::FixedEccPoint, BaseFieldEccChip, EccPoint}, + ecc::{fixed::FixedEcPoint, BaseFieldEccChip, EcPoint}, fields::FieldChip, }; use halo2_proofs::{ @@ -153,24 +153,23 @@ mod halo2_lib { plonk::Error, }; - type AssignedInteger = CRTInteger<::ScalarExt>; - type AssignedEcPoint = EccPoint<::ScalarExt, AssignedInteger>; + type AssignedInteger<'v, C> = CRTInteger<'v, ::ScalarExt>; + type AssignedEcPoint<'v, C> = EcPoint<::ScalarExt, AssignedInteger<'v, C>>; - impl<'a, F: FieldExt> Context for halo2_base::Context<'a, F> { + impl<'a, F: BigPrimeField> Context for halo2_base::Context<'a, F> { fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error> { self.region.constrain_equal(lhs, rhs) } fn offset(&self) -> usize { - dbg!("using context offset"); - *self.advice_rows.values().flatten().max().unwrap() + unreachable!() } } - impl<'a, F: FieldExt> IntegerInstructions<'a, F> for FlexGateConfig { + impl<'a, F: BigPrimeField> IntegerInstructions<'a, F> for FlexGateConfig { type Context = halo2_base::Context<'a, F>; type Integer = F; - type AssignedInteger = AssignedValue; + type AssignedInteger = AssignedValue<'a, F>; fn integer(&self, scalar: F) -> Self::Integer { scalar @@ -181,7 +180,7 @@ mod halo2_lib { ctx: &mut Self::Context, integer: Value, ) -> Result { - Ok(self.assign_witnesses(ctx, vec![integer])?.pop().unwrap()) + Ok(self.assign_region_last(ctx, vec![Witness(integer)], vec![])) } fn assign_constant( @@ -189,10 +188,7 @@ mod halo2_lib { ctx: &mut Self::Context, integer: F, ) -> Result { - Ok(self - .assign_region(ctx, vec![Constant(integer)], vec![], None)? - .pop() - .unwrap()) + Ok(self.assign_region_last(ctx, vec![Constant(integer)], vec![])) } fn sum_with_coeff_and_const( @@ -209,8 +205,7 @@ mod halo2_lib { } a.extend(values.iter().map(|(_, a)| Existing(a))); b.extend(values.iter().map(|(c, _)| Constant(*c))); - let (_, _, sum) = self.inner_product(ctx, a, b)?; - Ok(sum) + Ok(self.inner_product(ctx, a, b)) } fn sum_products_with_coeff_and_const( @@ -221,15 +216,13 @@ mod halo2_lib { ) -> Result { match values.len() { 0 => self.assign_constant(ctx, constant), - _ => { - let mut prods = Vec::with_capacity(values.len()); - for (c, a, b) in values.into_iter() { - let a = Existing(&a); - let b = Existing(&b); - prods.push((*c, a, b)); - } - self.sum_products_with_coeff_and_var(ctx, &prods, &Constant(constant)) - } + _ => Ok(self.sum_products_with_coeff_and_var( + ctx, + values + .iter() + .map(|(c, a, b)| (*c, Existing(a), Existing(b))), + Constant(constant), + )), } } @@ -239,7 +232,7 @@ mod halo2_lib { a: &Self::AssignedInteger, b: &Self::AssignedInteger, ) -> Result { - GateInstructions::sub(self, ctx, &Existing(a), &Existing(b)) + Ok(GateInstructions::sub(self, ctx, Existing(a), Existing(b))) } fn neg( @@ -247,7 +240,7 @@ mod halo2_lib { ctx: &mut Self::Context, a: &Self::AssignedInteger, ) -> Result { - GateInstructions::neg(self, ctx, &Existing(a)) + Ok(GateInstructions::neg(self, ctx, Existing(a))) } fn invert( @@ -256,9 +249,14 @@ mod halo2_lib { a: &Self::AssignedInteger, ) -> Result { // make sure scalar != 0 - let is_zero = self.is_zero(ctx, a)?; - self.assert_is_const(ctx, &is_zero, F::zero())?; - GateInstructions::div_unsafe(self, ctx, &Constant(F::one()), &Existing(a)) + let is_zero = self.is_zero(ctx, a); + self.assert_is_const(ctx, &is_zero, F::zero()); + Ok(GateInstructions::div_unsafe( + self, + ctx, + Constant(F::one()), + Existing(a), + )) } fn assert_equal( @@ -271,12 +269,16 @@ mod halo2_lib { } } - impl<'a, 'b, C: CurveAffine> EccInstructions<'a, C> for BaseFieldEccChip<'b, C> { + impl<'a, 'b, C: CurveAffine> EccInstructions<'a, C> for BaseFieldEccChip<'b, C> + where + C::Scalar: BigPrimeField, + C::Base: BigPrimeField, + { type Context = halo2_base::Context<'a, C::Scalar>; type ScalarChip = FlexGateConfig; - type AssignedEcPoint = AssignedEcPoint; + type AssignedEcPoint = AssignedEcPoint<'a, C>; type Scalar = C::Scalar; - type AssignedScalar = AssignedValue; + type AssignedScalar = AssignedValue<'a, C::Scalar>; fn scalar_chip(&self) -> &Self::ScalarChip { self.field_chip.range().gate() @@ -287,12 +289,12 @@ mod halo2_lib { ctx: &mut Self::Context, point: C, ) -> Result { - let fixed = FixedEccPoint::::from_g1( + let fixed = FixedEcPoint::::from_g1( &point, self.field_chip.num_limbs, self.field_chip.limb_bits, ); - FixedEccPoint::assign(fixed, self.field_chip, ctx) + Ok(FixedEcPoint::assign(fixed, self.field_chip, ctx)) } fn assign_point( @@ -300,13 +302,13 @@ mod halo2_lib { ctx: &mut Self::Context, point: Value, ) -> Result { - let assigned = self.assign_point(ctx, point)?; - let is_on_curve_or_infinity = self.is_on_curve_or_infinity::(ctx, &assigned)?; + let assigned = self.assign_point(ctx, point); + let is_on_curve_or_infinity = self.is_on_curve_or_infinity::(ctx, &assigned); self.field_chip.range.gate.assert_is_const( ctx, &is_on_curve_or_infinity, C::Scalar::one(), - )?; + ); Ok(assigned) } @@ -316,12 +318,13 @@ mod halo2_lib { values: &[Self::AssignedEcPoint], constant: C, ) -> Result { - if bool::from(constant.is_identity()) { - self.sum::(ctx, values.iter()) + let constant = if bool::from(constant.is_identity()) { + None } else { - let constant = EccInstructions::::assign_constant(self, ctx, constant)?; - self.sum::(ctx, values.iter().chain([constant].iter()).into_iter()) - } + let constant = EccInstructions::::assign_constant(self, ctx, constant).unwrap(); + Some(constant) + }; + Ok(self.sum::(ctx, constant.iter().chain(values.iter()))) } fn variable_base_msm( @@ -330,13 +333,13 @@ mod halo2_lib { pairs: &[(Self::AssignedScalar, Self::AssignedEcPoint)], ) -> Result { let (scalars, points): (Vec<_>, Vec<_>) = pairs.iter().cloned().unzip(); - self.multi_scalar_mult::( + Ok(self.multi_scalar_mult::( ctx, &points, &scalars.into_iter().map(|scalar| vec![scalar]).collect(), ::NUM_BITS as usize, 4, // empirically clump factor of 4 seems to be best - ) + )) } fn fixed_base_msm( @@ -345,15 +348,15 @@ mod halo2_lib { pairs: &[(Self::AssignedScalar, C)], ) -> Result { let (scalars, points): (Vec<_>, Vec<_>) = pairs.iter().cloned().unzip(); - BaseFieldEccChip::::fixed_base_msm::( - &self, + Ok(BaseFieldEccChip::::fixed_base_msm::( + self, ctx, &points, &scalars.into_iter().map(|scalar| vec![scalar]).collect(), ::NUM_BITS as usize, 0, 4, - ) + )) } fn normalize( @@ -370,7 +373,8 @@ mod halo2_lib { a: &Self::AssignedEcPoint, b: &Self::AssignedEcPoint, ) -> Result<(), Error> { - self.assert_equal(ctx, a, b) + self.assert_equal(ctx, a, b); + Ok(()) } } } diff --git a/src/pcs/kzg/accumulator.rs b/src/pcs/kzg/accumulator.rs index 034e231c..27a64330 100644 --- a/src/pcs/kzg/accumulator.rs +++ b/src/pcs/kzg/accumulator.rs @@ -222,13 +222,16 @@ mod halo2 { mod halo2_lib { use super::*; use halo2_base::AssignedValue; - use halo2_ecc::{bigint::CRTInteger, ecc::EccPoint}; + use halo2_curves::BigPrimeField; + use halo2_ecc::{bigint::CRTInteger, ecc::EcPoint}; impl<'a, C, PCS, EccChip, const LIMBS: usize, const BITS: usize> AccumulatorEncoding>, PCS> for LimbsEncoding where C: CurveAffine, + C::Scalar: BigPrimeField, + C::Base: BigPrimeField, PCS: PolynomialCommitmentScheme< C, Rc>, @@ -237,8 +240,8 @@ mod halo2 { EccChip: EccInstructions< 'a, C, - AssignedEcPoint = EccPoint>, - AssignedScalar = AssignedValue, + AssignedEcPoint = EcPoint>, + AssignedScalar = AssignedValue<'a, C::Scalar>, >, { fn from_repr(limbs: Vec>) -> Result { diff --git a/src/system/halo2.rs b/src/system/halo2.rs index 3d7cf140..f0e050d4 100644 --- a/src/system/halo2.rs +++ b/src/system/halo2.rs @@ -137,11 +137,7 @@ pub fn compile<'a, C: CurveAffine, P: Params<'a, C>>( let instance_committing_key = query_instance.then(|| { instance_committing_key( params, - polynomials - .num_instance() - .into_iter() - .max() - .unwrap_or_default(), + Iterator::max(polynomials.num_instance().into_iter()).unwrap_or_default(), ) }); @@ -618,9 +614,9 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { .iter() .zip( iter::successors( - Some(F::DELTA.pow_vartime(&[(i - * self.permutation_chunk_size) - as u64])), + Some(F::DELTA.pow_vartime([ + (i * self.permutation_chunk_size) as u64, + ])), |delta| Some(F::DELTA * delta), ) .map(Expression::Constant), diff --git a/src/system/halo2/transcript/halo2.rs b/src/system/halo2/transcript/halo2.rs index faa29207..d240dbba 100644 --- a/src/system/halo2/transcript/halo2.rs +++ b/src/system/halo2/transcript/halo2.rs @@ -417,10 +417,14 @@ impl< mod halo2_lib { use crate::system::halo2::transcript::halo2::EncodeNative; - use halo2_curves::CurveAffine; + use halo2_curves::{BigPrimeField, CurveAffine}; use halo2_ecc::ecc::BaseFieldEccChip; - impl<'a, 'b, C: CurveAffine> EncodeNative<'a, C, C::Scalar> for BaseFieldEccChip<'b, C> { + impl<'a, 'b, C: CurveAffine> EncodeNative<'a, C, C::Scalar> for BaseFieldEccChip<'b, C> + where + C::Scalar: BigPrimeField, + C::Base: BigPrimeField, + { fn encode_native( &self, _: &mut Self::Context, diff --git a/src/util/arithmetic.rs b/src/util/arithmetic.rs index b9a5c7c6..681a54d5 100644 --- a/src/util/arithmetic.rs +++ b/src/util/arithmetic.rs @@ -15,7 +15,7 @@ pub use halo2_curves::{ Curve, Group, GroupEncoding, }, pairing::MillerLoopResult, - Coordinates, CurveAffine, CurveExt, FieldExt, + BigPrimeField, Coordinates, CurveAffine, CurveExt, FieldExt, }; pub trait MultiMillerLoop: halo2_curves::pairing::MultiMillerLoop + Debug {} @@ -128,8 +128,8 @@ impl Domain { pub fn rotate_scalar(&self, scalar: F, rotation: Rotation) -> F { match rotation.0.cmp(&0) { Ordering::Equal => scalar, - Ordering::Greater => scalar * self.gen.pow_vartime(&[rotation.0 as u64]), - Ordering::Less => scalar * self.gen_inv.pow_vartime(&[(-rotation.0) as u64]), + Ordering::Greater => scalar * self.gen.pow_vartime([rotation.0 as u64]), + Ordering::Less => scalar * self.gen_inv.pow_vartime([(-rotation.0) as u64]), } } } diff --git a/src/verifier/plonk.rs b/src/verifier/plonk.rs index c3276af4..e5740493 100644 --- a/src/verifier/plonk.rs +++ b/src/verifier/plonk.rs @@ -451,11 +451,8 @@ where (min, max) } }); - let max_instance_len = instances - .iter() - .map(|instance| instance.len()) - .max() - .unwrap_or_default(); + let max_instance_len = + Iterator::max(instances.iter().map(|instance| instance.len())).unwrap_or_default(); -max_rotation..max_instance_len as i32 + min_rotation.abs() }); protocol From f050966ba4f918c6d6bb9a59c0753b119c9e9e9f Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Sun, 27 Nov 2022 16:59:50 -0500 Subject: [PATCH 21/28] wip: recursion example currently works if circuit does not use simple selectors * Simple selector compression is not correctly handled by `CsProxy` at the moment --- Cargo.toml | 6 +-- configs/verify_circuit.config | 2 +- examples/recursion.rs | 78 ++++++++++++++++++++++++++++------- src/loader/halo2/shim.rs | 33 +++++++++------ src/pcs/kzg/decider.rs | 10 ++++- 5 files changed, 95 insertions(+), 34 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 68e04568..a1e0a716 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,8 +37,8 @@ bytes = { version = "1.2", optional = true } rlp = { version = "0.5", default-features = false, features = ["std"], optional = true } # loader_halo2 -halo2_base = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", tag = "v0.2.1", package = "halo2_base", default-features = false, optional = true } -halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", tag = "v0.2.1", package = "halo2_ecc", default-features = false, optional = true } +halo2_base = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", branch = "faster-and-compatible", package = "halo2_base", default-features = false, optional = true } +halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", branch = "faster-and-compatible", package = "halo2_ecc", default-features = false, optional = true } halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", tag = "v2022_10_22", package = "ecc", optional = true } poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon", tag = "v2022_10_22", optional = true } @@ -48,7 +48,7 @@ rand_chacha = "0.3.1" paste = "1.0.7" # system_halo2 -halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", tag = "v0.2.1", package = "halo2_ecc", default-features = false } +halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", branch = "faster-and-compatible", package = "halo2_ecc" } halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", tag = "v2022_10_22", package = "ecc" } # loader_evm diff --git a/configs/verify_circuit.config b/configs/verify_circuit.config index c7436f39..b146ac61 100644 --- a/configs/verify_circuit.config +++ b/configs/verify_circuit.config @@ -1 +1 @@ -{"strategy":"Simple","degree":21,"num_advice":10,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} \ No newline at end of file +{"strategy":"Simple","degree":21,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/examples/recursion.rs b/examples/recursion.rs index d2cbaabf..6db60cbf 100644 --- a/examples/recursion.rs +++ b/examples/recursion.rs @@ -1,5 +1,6 @@ #![allow(clippy::type_complexity)] +use ark_std::{end_timer, start_timer}; use common::*; use halo2_base::utils::fs::gen_srs; use halo2_curves::{ @@ -45,6 +46,8 @@ use rand_chacha::{ }; use std::{fs, iter, marker::PhantomData, rc::Rc}; +use crate::recursion::AggregationConfigParams; + const LIMBS: usize = 3; const BITS: usize = 88; const T: usize = 5; @@ -137,6 +140,11 @@ mod common { fn accumulator_indices() -> Option> { None } + + /// Output the simple selector columns (before selector compression) of the circuit + fn selectors(config: &Self::Config) -> Vec { + vec![] + } } pub fn gen_pk>(params: &ParamsKZG, circuit: &C) -> ProvingKey { @@ -151,9 +159,11 @@ mod common { instances: Vec>, ) -> Vec { if params.k() > 3 { + let mock = start_timer!(|| "Mock prover"); MockProver::run(params.k(), &circuit, instances.clone()) .unwrap() .assert_satisfied(); + end_timer!(mock); } let instances = instances.iter().map(Vec::as_slice).collect_vec(); @@ -214,7 +224,7 @@ mod common { ) -> Snark { struct CsProxy(PhantomData<(F, C)>); - impl> Circuit for CsProxy { + impl> Circuit for CsProxy { type Config = C::Config; type FloorPlanner = C::FloorPlanner; @@ -226,7 +236,21 @@ mod common { C::configure(meta) } - fn synthesize(&self, _: Self::Config, _: impl Layouter) -> Result<(), Error> { + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // when `C` has simple selectors, we tell `CsProxy` not to over-optimize the selectors (e.g., compressing them all into one) by turning all selectors on in the first row + layouter.assign_region( + || "", + |mut region| { + for q in C::selectors(&config).iter() { + q.enable(&mut region, 0)?; + } + Ok(()) + }, + )?; Ok(()) } } @@ -563,10 +587,7 @@ mod recursion { .protocol .preprocessed .iter() - .flat_map(|preprocessed| { - let coordinates = preprocessed.coordinates().unwrap(); - [*coordinates.x(), *coordinates.y()] - }) + .flat_map(|preprocessed| [preprocessed.x, preprocessed.y]) .map(fe_to_fe) .chain(previous.protocol.transcript_initial_state) .collect_vec(); @@ -620,15 +641,13 @@ mod recursion { ) -> Result>>, Error> { let [lhs, rhs] = [self.default_accumulator.lhs, self.default_accumulator.rhs].map(|default| { - loader + let assigned = loader .ecc_chip() .assign_constant(&mut loader.ctx_mut(), default) - .unwrap() + .unwrap(); + loader.ec_point_from_assigned(assigned) }); - Ok(KzgAccumulator::new( - loader.ec_point_from_assigned(lhs), - loader.ec_point_from_assigned(rhs), - )) + Ok(KzgAccumulator::new(lhs, rhs)) } } @@ -652,7 +671,7 @@ mod recursion { let path = std::env::var("VERIFY_CONFIG") .unwrap_or_else(|_| "configs/verify_circuit.config".to_owned()); let params: AggregationConfigParams = serde_json::from_reader( - File::open(path.as_str()).expect(format!("{path} file should exist").as_str()), + File::open(path.as_str()).unwrap_or_else(|err| panic!("{err:?}")), ) .unwrap(); @@ -790,6 +809,10 @@ mod recursion { ] { ctx.region.constrain_equal(lhs.cell(), rhs.cell())?; } + + // IMPORTANT: + config.base_field_config.finalize(&mut ctx); + assigned_instances.extend( [lhs.x(), lhs.y(), rhs.x(), rhs.y()] .into_iter() @@ -801,6 +824,7 @@ mod recursion { }, )?; + assert_eq!(assigned_instances.len(), 4 * LIMBS + 4); for (row, limb) in assigned_instances.into_iter().enumerate() { layouter.constrain_instance(limb, config.instance, row)?; } @@ -822,6 +846,17 @@ mod recursion { fn accumulator_indices() -> Option> { Some((0..4 * LIMBS).map(|idx| (0, idx)).collect()) } + + /*fn selectors(config: &Self::Config) -> Vec { + config + .base_field_config + .range + .gate + .basic_gates + .iter() + .map(|gate| gate.q_enable) + .collect() + }*/ } pub fn gen_recursion_pk>( @@ -854,14 +889,17 @@ mod recursion { RecursionCircuit::initial_snark(recursion_params, Some(recursion_pk.get_vk())); for (round, input) in inputs.into_iter().enumerate() { state = app.state_transition(input); + println!("Generate app snark"); + let app_snark = gen_snark(app_params, app_pk, app); let recursion = RecursionCircuit::new( recursion_params, - gen_snark(app_params, app_pk, app), + app_snark, previous, initial_state, state, round, ); + println!("Generate recursion snark"); previous = gen_snark(recursion_params, recursion_pk, recursion); app = ConcreteCircuit::new(state); } @@ -871,16 +909,23 @@ mod recursion { fn main() { let app_params = gen_srs(3); - let recursion_params = gen_srs(22); + let recursion_config: AggregationConfigParams = + serde_json::from_reader(fs::File::open("configs/verify_circuit.config").unwrap()).unwrap(); + let k = recursion_config.degree; + let recursion_params = gen_srs(k); let app_pk = gen_pk(&app_params, &application::Square::default()); + + let pk_time = start_timer!(|| "Generate recursion pk"); let recursion_pk = recursion::gen_recursion_pk::( &recursion_params, &app_params, app_pk.get_vk(), ); + end_timer!(pk_time); - let num_round = 3; + let num_round = 1; + let pf_time = start_timer!(|| "Generate full recursive snark"); let (final_state, snark) = recursion::gen_recursion_snark::( &app_params, &recursion_params, @@ -889,6 +934,7 @@ fn main() { Fr::from(2u64), vec![(); num_round], ); + end_timer!(pf_time); assert_eq!(final_state, Fr::from(2u64).pow(&[1 << num_round, 0, 0, 0])); let accept = { diff --git a/src/loader/halo2/shim.rs b/src/loader/halo2/shim.rs index 04de3d05..bde101f5 100644 --- a/src/loader/halo2/shim.rs +++ b/src/loader/halo2/shim.rs @@ -335,16 +335,15 @@ mod halo2_lib { impl Deref, )], ) -> Result { + let (scalars, points): (Vec<_>, Vec<_>) = pairs + .iter() + .map(|(scalar, point)| (vec![scalar.deref().clone()], point.deref().clone())) + .unzip(); + Ok(self.multi_scalar_mult::( ctx, - &pairs - .iter() - .map(|(_, point)| point.deref().clone()) - .collect_vec(), - &pairs - .iter() - .map(|(scalar, _)| vec![scalar.deref().clone()]) - .collect_vec(), + &points, + &scalars, ::NUM_BITS as usize, 4, // empirically clump factor of 4 seems to be best )) @@ -355,14 +354,22 @@ mod halo2_lib { ctx: &mut Self::Context, pairs: &[(impl Deref, C)], ) -> Result { + let (scalars, points): (Vec<_>, Vec<_>) = pairs + .iter() + .filter_map(|(scalar, point)| { + if point.is_identity().into() { + None + } else { + Some((vec![scalar.deref().clone()], *point)) + } + }) + .unzip(); + Ok(BaseFieldEccChip::::fixed_base_msm::( self, ctx, - &pairs.iter().map(|(_, point)| *point).collect_vec(), - &pairs - .iter() - .map(|(scalar, _)| vec![scalar.deref().clone()]) - .collect_vec(), + &points, + &scalars, ::NUM_BITS as usize, 0, 4, diff --git a/src/pcs/kzg/decider.rs b/src/pcs/kzg/decider.rs index de3e2a06..0cb8b720 100644 --- a/src/pcs/kzg/decider.rs +++ b/src/pcs/kzg/decider.rs @@ -60,7 +60,15 @@ mod native { ) -> bool { !accumulators .into_iter() - .any(|accumulator| !Self::decide(dk, accumulator)) + //.enumerate() + .any(|accumulator| { + /*let decide = Self::decide(dk, accumulator); + if !decide { + panic!("{i}"); + } + !decide*/ + !Self::decide(dk, accumulator) + }) } } } From 721228337efad93862602d05959ab6fd8c2809c5 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Sun, 27 Nov 2022 17:27:36 -0500 Subject: [PATCH 22/28] fix: got recursion example working with halo2-lib * where halo2-lib uses simple selectors * only works if all selector columns are in active use and can't be optimized away * tricks `CsProxy` by turning all selectors on in row `0` --- Cargo.toml | 6 +++--- configs/verify_circuit.config | 2 +- examples/recursion.rs | 9 ++++++--- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a1e0a716..418d3343 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,8 +37,8 @@ bytes = { version = "1.2", optional = true } rlp = { version = "0.5", default-features = false, features = ["std"], optional = true } # loader_halo2 -halo2_base = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", branch = "faster-and-compatible", package = "halo2_base", default-features = false, optional = true } -halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", branch = "faster-and-compatible", package = "halo2_ecc", default-features = false, optional = true } +halo2_base = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", tag = "v0.2.1", package = "halo2_base", default-features = false, optional = true } +halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", tag = "v0.2.1", package = "halo2_ecc", default-features = false, optional = true } halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", tag = "v2022_10_22", package = "ecc", optional = true } poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon", tag = "v2022_10_22", optional = true } @@ -48,7 +48,7 @@ rand_chacha = "0.3.1" paste = "1.0.7" # system_halo2 -halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", branch = "faster-and-compatible", package = "halo2_ecc" } +halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", tag = "v0.2.1", package = "halo2_ecc" } halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2wrong", tag = "v2022_10_22", package = "ecc" } # loader_evm diff --git a/configs/verify_circuit.config b/configs/verify_circuit.config index b146ac61..e65b2b52 100644 --- a/configs/verify_circuit.config +++ b/configs/verify_circuit.config @@ -1 +1 @@ -{"strategy":"Simple","degree":21,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} \ No newline at end of file +{"strategy":"Simple","degree":21,"num_advice":4,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/examples/recursion.rs b/examples/recursion.rs index 6db60cbf..1b49f4e6 100644 --- a/examples/recursion.rs +++ b/examples/recursion.rs @@ -142,7 +142,7 @@ mod common { } /// Output the simple selector columns (before selector compression) of the circuit - fn selectors(config: &Self::Config) -> Vec { + fn selectors(_: &Self::Config) -> Vec { vec![] } } @@ -242,6 +242,7 @@ mod common { mut layouter: impl Layouter, ) -> Result<(), Error> { // when `C` has simple selectors, we tell `CsProxy` not to over-optimize the selectors (e.g., compressing them all into one) by turning all selectors on in the first row + // currently this only works if all simple selector columns are used in the actual circuit and there are overlaps amongst all enabled selectors (i.e., the actual circuit will not optimize constraint system further) layouter.assign_region( || "", |mut region| { @@ -812,6 +813,8 @@ mod recursion { // IMPORTANT: config.base_field_config.finalize(&mut ctx); + dbg!(ctx.total_advice); + println!("Advice columns used: {}", ctx.advice_alloc[0][0].0 + 1); assigned_instances.extend( [lhs.x(), lhs.y(), rhs.x(), rhs.y()] @@ -847,7 +850,7 @@ mod recursion { Some((0..4 * LIMBS).map(|idx| (0, idx)).collect()) } - /*fn selectors(config: &Self::Config) -> Vec { + fn selectors(config: &Self::Config) -> Vec { config .base_field_config .range @@ -856,7 +859,7 @@ mod recursion { .iter() .map(|gate| gate.q_enable) .collect() - }*/ + } } pub fn gen_recursion_pk>( From 8b2daae5df30ab374cb8e82e6a3ee487b8f6d296 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Thu, 1 Dec 2022 19:06:24 -0500 Subject: [PATCH 23/28] feat: add `sdk` module for snark aggregation toolkit --- Cargo.toml | 18 +- benches/standard_plonk.rs | 261 ++++++++++++++ configs/example_evm_accumulator.config | 1 + src/lib.rs | 1 + src/sdk.rs | 466 +++++++++++++++++++++++++ src/sdk/aggregation.rs | 341 ++++++++++++++++++ src/system/halo2/transcript/halo2.rs | 31 ++ src/util/hash/poseidon.rs | 20 +- 8 files changed, 1132 insertions(+), 7 deletions(-) create mode 100644 benches/standard_plonk.rs create mode 100644 configs/example_evm_accumulator.config create mode 100644 src/sdk.rs create mode 100644 src/sdk/aggregation.rs diff --git a/Cargo.toml b/Cargo.toml index 76989bb1..3fecf432 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,9 +10,12 @@ num-bigint = "0.4.3" num-integer = "0.1.45" num-traits = "0.2.15" rand = "0.8" +rand_chacha = "0.3.1" hex = "0.4" serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +serde_json = "1.0" +bincode = "1.3.3" +ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } # Use halo2_base as non-optional dependency because it re-exports halo2_proofs, halo2curves, and poseidon, using different repos based on feature flag "halo2-axiom" or "halo2-pse" halo2_base = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", tag = "v0.2.1", package = "halo2_base", default-features = false } @@ -36,8 +39,10 @@ zkevm_circuits = {git = "https://github.com/privacy-scaling-explorations/zkevm-c [dev-dependencies] ark-std = { version = "0.3.0", features = ["print-trace"] } -rand_chacha = "0.3.1" paste = "1.0.7" +pprof = { version = "0.11", features = ["criterion", "flamegraph"] } +criterion = "0.4" +criterion-macro = "0.4" # system_halo2 halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", tag = "v0.2.1", package = "halo2_ecc", default-features = false } @@ -47,8 +52,8 @@ crossterm = { version = "0.25" } tui = { version = "0.19", default-features = false, features = ["crossterm"] } [features] -default = ["loader_evm", "loader_halo2", "system_halo2", "halo2-axiom", "display"] -display = ["halo2_base/display", "halo2_ecc?/display"] +default = ["loader_evm", "loader_halo2", "system_halo2", "halo2-axiom", "halo2_ecc/jemallocator"] +display = ["halo2_base/display", "halo2_ecc?/display", "ark-std"] # EXACTLY one of halo2-pse / halo2-axiom should always be turned on; not sure how to enforce this with Cargo halo2-pse = ["halo2_base/halo2-pse", "halo2_ecc?/halo2-pse"] halo2-axiom = ["halo2_base/halo2-axiom", "halo2_ecc?/halo2-axiom"] @@ -73,6 +78,11 @@ required-features = ["loader_halo2", "system_halo2", "loader_evm"] name = "recursion" required-features = ["loader_halo2", "system_halo2"] +[[bench]] +name = "standard_plonk" +required-features = ["loader_halo2", "system_halo2"] +harness = false + [profile.dev] opt-level = 3 diff --git a/benches/standard_plonk.rs b/benches/standard_plonk.rs new file mode 100644 index 00000000..6b89d945 --- /dev/null +++ b/benches/standard_plonk.rs @@ -0,0 +1,261 @@ +use criterion::{criterion_group, criterion_main}; +use criterion::{BenchmarkId, Criterion}; +use plonk_verifier::sdk::aggregation::AggregationCircuit; +use plonk_verifier::sdk::{ + self, gen_pk, gen_proof_shplonk, gen_snark_shplonk, PoseidonTranscript, POSEIDON_SPEC, +}; +use pprof::criterion::{Output, PProfProfiler}; + +use ark_std::{end_timer, start_timer}; + +use halo2_base::halo2_proofs; +use halo2_proofs::halo2curves as halo2_curves; +use halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, + plonk::{create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey}, + poly::{ + commitment::{Params, ParamsProver}, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::{ProverGWC, VerifierGWC}, + strategy::AccumulatorStrategy, + }, + VerificationStrategy, + }, + transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer}, +}; +use itertools::Itertools; +use plonk_verifier::{ + loader::{ + evm::{encode_calldata, EvmLoader, ExecutorBuilder}, + native::NativeLoader, + }, + pcs::kzg::{Gwc19, Kzg, KzgAs, LimbsEncoding}, + system::halo2::{compile, transcript::evm::EvmTranscript, Config}, + verifier::{self, PlonkVerifier}, +}; +use rand::rngs::OsRng; +use rand::SeedableRng; +use rand_chacha::ChaCha20Rng; +use std::{io::Cursor, rc::Rc}; + +mod application { + use super::halo2_curves::bn256::Fr; + use super::halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, + poly::Rotation, + }; + use halo2_base::halo2_proofs::plonk::Assigned; + use plonk_verifier::sdk::CircuitExt; + use rand::RngCore; + + #[derive(Clone, Copy)] + pub struct StandardPlonkConfig { + a: Column, + b: Column, + c: Column, + q_a: Column, + q_b: Column, + q_c: Column, + q_ab: Column, + constant: Column, + #[allow(dead_code)] + instance: Column, + } + + impl StandardPlonkConfig { + fn configure(meta: &mut ConstraintSystem) -> Self { + let [a, b, c] = [(); 3].map(|_| meta.advice_column()); + let [q_a, q_b, q_c, q_ab, constant] = [(); 5].map(|_| meta.fixed_column()); + let instance = meta.instance_column(); + + [a, b, c].map(|column| meta.enable_equality(column)); + + meta.create_gate( + "q_a·a + q_b·b + q_c·c + q_ab·a·b + constant + instance = 0", + |meta| { + let [a, b, c] = + [a, b, c].map(|column| meta.query_advice(column, Rotation::cur())); + let [q_a, q_b, q_c, q_ab, constant] = [q_a, q_b, q_c, q_ab, constant] + .map(|column| meta.query_fixed(column, Rotation::cur())); + let instance = meta.query_instance(instance, Rotation::cur()); + Some( + q_a * a.clone() + + q_b * b.clone() + + q_c * c + + q_ab * a * b + + constant + + instance, + ) + }, + ); + + StandardPlonkConfig { a, b, c, q_a, q_b, q_c, q_ab, constant, instance } + } + } + + #[derive(Clone, Default)] + pub struct StandardPlonk(Fr); + + impl StandardPlonk { + pub fn rand(mut rng: R) -> Self { + Self(Fr::from(rng.next_u32() as u64)) + } + } + + impl CircuitExt for StandardPlonk { + fn num_instance() -> Vec { + vec![1] + } + + fn instances(&self) -> Vec> { + vec![vec![self.0]] + } + } + + impl Circuit for StandardPlonk { + type Config = StandardPlonkConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + meta.set_minimum_degree(4); + StandardPlonkConfig::configure(meta) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region( + || "", + |mut region| { + #[cfg(feature = "halo2-pse")] + { + region.assign_advice(|| "", config.a, 0, || Value::known(self.0))?; + region.assign_fixed(|| "", config.q_a, 0, || Value::known(-Fr::one()))?; + + region.assign_advice( + || "", + config.a, + 1, + || Value::known(-Fr::from(5u64)), + )?; + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed( + || "", + column, + 1, + || Value::known(Fr::from(idx as u64)), + )?; + } + + let a = + region.assign_advice(|| "", config.a, 2, || Value::known(Fr::one()))?; + a.copy_advice(|| "", &mut region, config.b, 3)?; + a.copy_advice(|| "", &mut region, config.c, 4)?; + } + #[cfg(feature = "halo2-axiom")] + { + region.assign_advice( + config.a, + 0, + Value::known(Assigned::Trivial(self.0)), + )?; + region.assign_fixed(config.q_a, 0, Assigned::Trivial(-Fr::one())); + + region.assign_advice( + config.a, + 1, + Value::known(Assigned::Trivial(-Fr::from(5u64))), + )?; + for (idx, column) in (1..).zip([ + config.q_a, + config.q_b, + config.q_c, + config.q_ab, + config.constant, + ]) { + region.assign_fixed(column, 1, Assigned::Trivial(Fr::from(idx as u64))); + } + + let a = region.assign_advice( + config.a, + 2, + Value::known(Assigned::Trivial(Fr::one())), + )?; + a.copy_advice(&mut region, config.b, 3); + a.copy_advice(&mut region, config.c, 4); + } + + Ok(()) + }, + ) + } + } +} + +fn gen_application_snark( + params: &ParamsKZG, + transcript: &mut PoseidonTranscript>, +) -> sdk::Snark { + let circuit = application::StandardPlonk::rand(OsRng); + + let pk = gen_pk(params, &circuit, None); + gen_snark_shplonk(params, &pk, circuit, transcript, None) +} + +fn bench(c: &mut Criterion) { + std::env::set_var("VERIFY_CONFIG", "./configs/example_evm_accumulator.config"); + let k = 21; + let params = halo2_base::utils::fs::gen_srs(k); + let params_app = { + let mut params = params.clone(); + params.downsize(8); + params + }; + + let mut transcript = + PoseidonTranscript::::from_spec(vec![], POSEIDON_SPEC.clone()); + let snarks = [(); 3].map(|_| gen_application_snark(¶ms_app, &mut transcript)); + + let start1 = start_timer!(|| "Create aggregation circuit"); + let mut rng = ChaCha20Rng::from_entropy(); + let agg_circuit = AggregationCircuit::new(¶ms, snarks, &mut transcript, &mut rng); + end_timer!(start1); + + let pk = gen_pk(¶ms, &agg_circuit, None); + + let mut group = c.benchmark_group("plonk-prover"); + group.sample_size(10); + group.bench_with_input( + BenchmarkId::new("standard-plonk-agg", k), + &(¶ms, &pk, &agg_circuit), + |b, &(params, pk, agg_circuit)| { + b.iter(|| { + let instances = agg_circuit.instances(); + gen_proof_shplonk(params, pk, agg_circuit.clone(), instances, &mut transcript, None) + }) + }, + ); + group.finish(); +} + +criterion_group! { + name = benches; + config = Criterion::default().with_profiler(PProfProfiler::new(10, Output::Flamegraph(None))); + targets = bench +} +criterion_main!(benches); diff --git a/configs/example_evm_accumulator.config b/configs/example_evm_accumulator.config new file mode 100644 index 00000000..e65b2b52 --- /dev/null +++ b/configs/example_evm_accumulator.config @@ -0,0 +1 @@ +{"strategy":"Simple","degree":21,"num_advice":4,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index a6e4b3e6..3878c019 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ pub mod cost; pub mod loader; pub mod pcs; +pub mod sdk; pub mod system; pub mod util; pub mod verifier; diff --git a/src/sdk.rs b/src/sdk.rs new file mode 100644 index 00000000..520e9f8f --- /dev/null +++ b/src/sdk.rs @@ -0,0 +1,466 @@ +#![allow(clippy::type_complexity)] +use crate::cost::CostEstimation; +use crate::halo2_proofs; +use crate::pcs::MultiOpenScheme; +use crate::{ + loader::native::NativeLoader, + pcs, + poseidon::Spec, + system::halo2::{self, compile, Config}, + util::{hash, transcript::TranscriptWrite}, + verifier::PlonkProof, + Protocol, +}; +#[cfg(feature = "display")] +use ark_std::{end_timer, start_timer}; +use halo2_proofs::{ + circuit::{Layouter, Value}, + dev::MockProver, + halo2curves::{ + bn256::{Bn256, Fr, G1Affine}, + group::ff::Field, + }, + plonk::{ + create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ConstraintSystem, Error, + ProvingKey, Selector, VerifyingKey, + }, + poly::{ + commitment::{Params, ParamsProver, Prover, Verifier}, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + msm::DualMSM, + multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}, + strategy::{AccumulatorStrategy, GuardKZG}, + }, + VerificationStrategy, + }, +}; +use itertools::Itertools; +use lazy_static::lazy_static; +use rand::SeedableRng; +use rand_chacha::ChaCha20Rng; +use std::{ + fs::{self, File}, + io::{BufReader, BufWriter}, + iter, + marker::PhantomData, + path::Path, +}; + +pub mod aggregation; + +// Poseidon parameters +const T: usize = 5; +const RATE: usize = 4; +const R_F: usize = 8; +const R_P: usize = 60; + +pub type PoseidonTranscript = + halo2::transcript::halo2::PoseidonTranscript; +lazy_static! { + pub static ref POSEIDON_SPEC: Spec = Spec::new(R_F, R_P); +} + +pub struct Snark { + pub protocol: Protocol, + pub instances: Vec>, + pub proof: Vec, +} + +impl Snark { + pub fn new(protocol: Protocol, instances: Vec>, proof: Vec) -> Self { + Self { protocol, instances, proof } + } +} + +impl From for SnarkWitness { + fn from(snark: Snark) -> Self { + Self { + protocol: snark.protocol, + instances: snark + .instances + .into_iter() + .map(|instances| instances.into_iter().map(Value::known).collect_vec()) + .collect(), + proof: Value::known(snark.proof), + } + } +} + +#[derive(Clone)] +pub struct SnarkWitness { + pub protocol: Protocol, + pub instances: Vec>>, + pub proof: Value>, +} + +impl SnarkWitness { + pub fn without_witnesses(&self) -> Self { + SnarkWitness { + protocol: self.protocol.clone(), + instances: self + .instances + .iter() + .map(|instances| vec![Value::unknown(); instances.len()]) + .collect(), + proof: Value::unknown(), + } + } + + pub fn proof(&self) -> Value<&[u8]> { + self.proof.as_ref().map(Vec::as_slice) + } +} + +pub trait CircuitExt: Circuit { + fn num_instance() -> Vec; + + fn instances(&self) -> Vec>; + + fn accumulator_indices() -> Option> { + None + } + + /// Output the simple selector columns (before selector compression) of the circuit + fn selectors(_: &Self::Config) -> Vec { + vec![] + } +} + +pub fn gen_pk>( + params: &ParamsKZG, + circuit: &C, + path: Option<&Path>, +) -> ProvingKey { + if let Some(path) = path { + match File::open(path) { + Ok(f) => { + #[cfg(feature = "display")] + let read_time = start_timer!(|| format!("Reading pkey from {path:?}")); + + // TODO: bench if BufReader is indeed faster than Read + let mut bufreader = BufReader::new(f); + let pk = ProvingKey::read::<_, C>(&mut bufreader, params) + .expect("Reading pkey should not fail"); + + #[cfg(feature = "display")] + end_timer!(read_time); + + pk + } + Err(_) => { + #[cfg(feature = "display")] + let pk_time = start_timer!(|| "Generating vkey & pkey"); + + let vk = keygen_vk(params, circuit).unwrap(); + let pk = keygen_pk(params, vk, circuit).unwrap(); + + #[cfg(feature = "display")] + end_timer!(pk_time); + + #[cfg(feature = "display")] + let write_time = start_timer!(|| format!("Writing pkey to {path:?}")); + + path.parent().and_then(|dir| fs::create_dir_all(dir).ok()).unwrap(); + let mut f = BufWriter::new(File::create(path).unwrap()); + pk.write(&mut f).unwrap(); + + #[cfg(feature = "display")] + end_timer!(write_time); + + pk + } + } + } else { + #[cfg(feature = "display")] + let pk_time = start_timer!(|| "Generating vkey & pkey"); + + let vk = keygen_vk(params, circuit).unwrap(); + let pk = keygen_pk(params, vk, circuit).unwrap(); + + #[cfg(feature = "display")] + end_timer!(pk_time); + + pk + } +} + +/// Generates a native proof using either SHPLONK or GWC proving method. Uses Poseidon for Fiat-Shamir. +/// +/// Caches the instances and proof if `path` is specified. +pub fn gen_proof<'params, C, P, V>( + params: &'params ParamsKZG, + pk: &'params ProvingKey, + circuit: C, + instances: Vec>, + transcript: &mut PoseidonTranscript>, + path: Option<(&Path, &Path)>, +) -> Vec +where + C: Circuit, + P: Prover<'params, KZGCommitmentScheme>, + V: Verifier< + 'params, + KZGCommitmentScheme, + Guard = GuardKZG<'params, Bn256>, + MSMAccumulator = DualMSM<'params, Bn256>, + >, +{ + #[cfg(debug_assertions)] + { + MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); + } + + let mut proof: Option> = None; + + if let Some((instance_path, proof_path)) = path { + let cached_instances = read_instances(instance_path); + if matches!(cached_instances, Ok(tmp) if tmp == instances) && proof_path.exists() { + #[cfg(feature = "display")] + let read_time = start_timer!(|| format!("Reading proof from {proof_path:?}")); + + proof = Some(fs::read(proof_path).unwrap()); + + #[cfg(feature = "display")] + end_timer!(read_time); + } + } + + let instances = instances.iter().map(Vec::as_slice).collect_vec(); + + let proof = proof.unwrap_or_else(|| { + #[cfg(feature = "display")] + let proof_time = start_timer!(|| "Create proof"); + + transcript.clear(); + create_proof::<_, P, _, _, _, _>( + params, + pk, + &[circuit], + &[&instances], + &mut ChaCha20Rng::from_entropy(), + transcript, + ) + .unwrap(); + let proof = transcript.stream_mut().split_off(0); + + #[cfg(feature = "display")] + end_timer!(proof_time); + + if let Some((instance_path, proof_path)) = path { + write_instances(&instances, instance_path); + fs::write(proof_path, &proof).unwrap(); + } + proof + }); + + debug_assert!({ + let mut transcript = PoseidonTranscript::::new(proof.as_slice()); + VerificationStrategy::<_, V>::finalize( + verify_proof::<_, V, _, _, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript, + ) + .unwrap(), + ) + }); + + proof +} + +pub fn gen_proof_gwc>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: C, + instances: Vec>, + transcript: &mut PoseidonTranscript>, + path: Option<(&Path, &Path)>, +) -> Vec { + gen_proof::, VerifierGWC<_>>(params, pk, circuit, instances, transcript, path) +} + +pub fn gen_proof_shplonk>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: C, + instances: Vec>, + transcript: &mut PoseidonTranscript>, + path: Option<(&Path, &Path)>, +) -> Vec { + gen_proof::, VerifierSHPLONK<_>>( + params, pk, circuit, instances, transcript, path, + ) +} + +pub fn gen_snark<'params, ConcreteCircuit, P, V>( + params: &'params ParamsKZG, + pk: &'params ProvingKey, + circuit: ConcreteCircuit, + transcript: &mut PoseidonTranscript>, + path: Option<(&Path, &Path)>, +) -> Snark +where + ConcreteCircuit: CircuitExt, + P: Prover<'params, KZGCommitmentScheme>, + V: Verifier< + 'params, + KZGCommitmentScheme, + Guard = GuardKZG<'params, Bn256>, + MSMAccumulator = DualMSM<'params, Bn256>, + >, +{ + let protocol = compile( + params, + pk.get_vk(), + Config::kzg() + .with_num_instance(ConcreteCircuit::num_instance()) + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), + ); + + let instances = circuit.instances(); + let proof = gen_proof::( + params, + pk, + circuit, + instances.clone(), + transcript, + path, + ); + + Snark::new(protocol, instances, proof) +} + +pub fn gen_snark_gwc>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: ConcreteCircuit, + transcript: &mut PoseidonTranscript>, + path: Option<(&Path, &Path)>, +) -> Snark { + gen_snark::, VerifierGWC<_>>( + params, pk, circuit, transcript, path, + ) +} + +pub fn gen_snark_shplonk>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: ConcreteCircuit, + transcript: &mut PoseidonTranscript>, + path: Option<(&Path, &Path)>, +) -> Snark { + gen_snark::, VerifierSHPLONK<_>>( + params, pk, circuit, transcript, path, + ) +} + +pub fn read_instances(path: impl AsRef) -> Result>, bincode::Error> { + let f = File::open(path)?; + let reader = BufReader::new(f); + let instances: Vec> = bincode::deserialize_from(reader)?; + instances + .into_iter() + .map(|instance_column| { + instance_column + .iter() + .map(|bytes| { + Option::from(Fr::from_bytes(bytes)).ok_or(Box::new(bincode::ErrorKind::Custom( + "Invalid finite field point".to_owned(), + ))) + }) + .collect::, _>>() + }) + .collect() +} + +pub fn write_instances(instances: &[&[Fr]], path: impl AsRef) { + let instances: Vec> = instances + .iter() + .map(|instance_column| instance_column.iter().map(|x| x.to_bytes()).collect_vec()) + .collect_vec(); + let f = BufWriter::new(File::create(path).unwrap()); + bincode::serialize_into(f, &instances).unwrap(); +} + +pub fn gen_dummy_snark( + params: &ParamsKZG, + vk: Option<&VerifyingKey>, +) -> Snark +where + ConcreteCircuit: CircuitExt, + MOS: MultiOpenScheme + + CostEstimation>>, +{ + struct CsProxy(PhantomData<(F, C)>); + + impl> Circuit for CsProxy { + type Config = C::Config; + type FloorPlanner = C::FloorPlanner; + + fn without_witnesses(&self) -> Self { + CsProxy(PhantomData) + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + C::configure(meta) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // when `C` has simple selectors, we tell `CsProxy` not to over-optimize the selectors (e.g., compressing them all into one) by turning all selectors on in the first row + // currently this only works if all simple selector columns are used in the actual circuit and there are overlaps amongst all enabled selectors (i.e., the actual circuit will not optimize constraint system further) + layouter.assign_region( + || "", + |mut region| { + for q in C::selectors(&config).iter() { + q.enable(&mut region, 0)?; + } + Ok(()) + }, + )?; + Ok(()) + } + } + + let dummy_vk = vk + .is_none() + .then(|| keygen_vk(params, &CsProxy::(PhantomData)).unwrap()); + let protocol = compile( + params, + vk.or(dummy_vk.as_ref()).unwrap(), + Config::kzg() + .with_num_instance(ConcreteCircuit::num_instance()) + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), + ); + let instances = ConcreteCircuit::num_instance() + .into_iter() + .map(|n| iter::repeat(Fr::default()).take(n).collect()) + .collect(); + let proof = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + for _ in 0..protocol + .num_witness + .iter() + .chain(Some(&protocol.quotient.num_chunk())) + .sum::() + { + transcript.write_ec_point(G1Affine::default()).unwrap(); + } + for _ in 0..protocol.evaluations.len() { + transcript.write_scalar(Fr::default()).unwrap(); + } + let queries = PlonkProof::::empty_queries(&protocol); + for _ in 0..MOS::estimate_cost(&queries).num_commitment { + transcript.write_ec_point(G1Affine::default()).unwrap(); + } + transcript.finalize() + }; + + Snark::new(protocol, instances, proof) +} diff --git a/src/sdk/aggregation.rs b/src/sdk/aggregation.rs new file mode 100644 index 00000000..bd76c116 --- /dev/null +++ b/src/sdk/aggregation.rs @@ -0,0 +1,341 @@ +use crate::halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; +use crate::halo2_proofs::{ + circuit::{Cell, Layouter, SimpleFloorPlanner, Value}, + plonk::{self, Circuit, Column, ConstraintSystem, Instance}, + poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, +}; +use crate::pcs::kzg::Bdfg21; +use crate::{ + loader::{self, native::NativeLoader}, + pcs::{ + kzg::{Kzg, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding}, + AccumulationScheme, AccumulationSchemeProver, MultiOpenScheme, PolynomialCommitmentScheme, + }, + system, + util::arithmetic::fe_to_limbs, + verifier::{self, PlonkVerifier}, + Protocol, +}; +#[cfg(feature = "display")] +use ark_std::{end_timer, start_timer}; +use halo2_base::{AssignedValue, Context, ContextParams}; +use halo2_ecc::ecc::EccChip; +use itertools::Itertools; +use rand::rngs::OsRng; +use rand::{Rng, SeedableRng}; +use rand_chacha::ChaCha20Rng; +use std::{fs::File, rc::Rc}; + +const LIMBS: usize = 3; +const BITS: usize = 88; + +use super::{PoseidonTranscript, Snark, SnarkWitness, POSEIDON_SPEC}; + +type Svk = KzgSuccinctVerifyingKey; +type BaseFieldEccChip = halo2_ecc::ecc::BaseFieldEccChip; +type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; +/// PCS be either `Kzg` or `Kzg` +type Plonk = verifier::Plonk>; +type Shplonk = Plonk>; + +/// Core function used in `synthesize` to aggregate multiple `snarks`. +/// +/// Returns the assigned instances of previous snarks (all concatenated together) and the new final pair that needs to be verified in a pairing check +pub fn aggregate<'a, PCS>( + svk: &PCS::SuccinctVerifyingKey, + loader: &Rc>, + snarks: &[SnarkWitness], + as_proof: Value<&'_ [u8]>, +) -> ( + Vec>, + KzgAccumulator>>, +) +where + PCS: PolynomialCommitmentScheme< + G1Affine, + Rc>, + Accumulator = KzgAccumulator>>, + > + MultiOpenScheme>>, +{ + let assign_instances = |instances: &[Vec>]| { + instances + .iter() + .map(|instances| { + instances.iter().map(|instance| loader.assign_scalar(*instance)).collect_vec() + }) + .collect_vec() + }; + + // TODO pre-allocate capacity better + let mut previous_instances = vec![]; + let mut transcript = PoseidonTranscript::>, _>::from_spec( + loader, + Value::unknown(), + POSEIDON_SPEC.clone(), + ); + + let mut accumulators = snarks + .iter() + .flat_map(|snark| { + let protocol = snark.protocol.loaded(loader); + // TODO use 1d vector + let instances = assign_instances(&snark.instances); + + // read the transcript and perform Fiat-Shamir + // run through verification computation and produce the final pair `succinct` + transcript.new_stream(snark.proof()); + let proof = + Plonk::::read_proof(svk, &protocol, &instances, &mut transcript).unwrap(); + let accumulator = + Plonk::::succinct_verify(svk, &protocol, &instances, &proof).unwrap(); + + previous_instances.extend(instances.into_iter().flatten()); + + accumulator + }) + .collect_vec(); + + let acccumulator = if accumulators.len() > 1 { + transcript.new_stream(as_proof); + let proof = + KzgAs::::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); + KzgAs::::verify(&Default::default(), &accumulators, &proof).unwrap() + } else { + accumulators.pop().unwrap() + }; + + (previous_instances, acccumulator) +} + +#[derive(serde::Serialize, serde::Deserialize)] +pub struct AggregationConfigParams { + pub strategy: halo2_ecc::fields::fp::FpStrategy, + pub degree: u32, + pub num_advice: usize, + pub num_lookup_advice: usize, + pub num_fixed: usize, + pub lookup_bits: usize, + pub limb_bits: usize, + pub num_limbs: usize, +} + +#[derive(Clone)] +pub struct AggregationConfig { + pub base_field_config: halo2_ecc::fields::fp::FpConfig, + pub instance: Column, +} + +impl AggregationConfig { + pub fn configure(meta: &mut ConstraintSystem, params: AggregationConfigParams) -> Self { + assert!( + params.limb_bits == BITS && params.num_limbs == LIMBS, + "For now we fix limb_bits = {}, otherwise change code", + BITS + ); + let base_field_config = halo2_ecc::fields::fp::FpConfig::configure( + meta, + params.strategy, + &[params.num_advice], + &[params.num_lookup_advice], + params.num_fixed, + params.lookup_bits, + BITS, + LIMBS, + halo2_base::utils::modulus::(), + 0, + params.degree as usize, + ); + + let instance = meta.instance_column(); + meta.enable_equality(instance); + + Self { base_field_config, instance } + } + + pub fn range(&self) -> &halo2_base::gates::range::RangeConfig { + &self.base_field_config.range + } + + pub fn gate(&self) -> &halo2_base::gates::flex_gate::FlexGateConfig { + &self.base_field_config.range.gate + } + + pub fn ecc_chip(&self) -> halo2_ecc::ecc::BaseFieldEccChip { + EccChip::construct(self.base_field_config.clone()) + } +} + +/// Aggregation circuit that does not re-expose any public inputs from aggregated snarks +#[derive(Clone)] +pub struct AggregationCircuit { + svk: Svk, + snarks: Vec, + instances: Vec, + as_proof: Value>, +} + +impl AggregationCircuit { + pub fn new( + params: &ParamsKZG, + snarks: impl IntoIterator, + transcript_write: &mut PoseidonTranscript>, + rng: &mut impl Rng, + ) -> Self { + let svk = params.get_g()[0].into(); + let snarks = snarks.into_iter().collect_vec(); + + // TODO: this is all redundant calculation to get the public output + // Halo2 should just be able to expose public output to instance column directly + let mut transcript_read = + PoseidonTranscript::::from_spec(&[], POSEIDON_SPEC.clone()); + let accumulators = snarks + .iter() + .flat_map(|snark| { + transcript_read.new_stream(snark.proof.as_slice()); + let proof = Shplonk::read_proof( + &svk, + &snark.protocol, + &snark.instances, + &mut transcript_read, + ) + .unwrap(); + Shplonk::succinct_verify(&svk, &snark.protocol, &snark.instances, &proof).unwrap() + }) + .collect_vec(); + + let (accumulator, as_proof) = { + transcript_write.clear(); + // We always use SHPLONK for accumulation scheme when aggregating proofs + let accumulator = KzgAs::>::create_proof( + &Default::default(), + &accumulators, + transcript_write, + rng, + ) + .unwrap(); + (accumulator, transcript_write.stream_mut().split_off(0)) + }; + + let KzgAccumulator { lhs, rhs } = accumulator; + let instances = [lhs.x, lhs.y, rhs.x, rhs.y].map(fe_to_limbs::<_, _, LIMBS, BITS>).concat(); + + Self { + svk, + snarks: snarks.into_iter().map_into().collect(), + instances, + as_proof: Value::known(as_proof), + } + } + + pub fn accumulator_indices() -> Vec<(usize, usize)> { + (0..4 * LIMBS).map(|idx| (0, idx)).collect() + } + + pub fn num_instance() -> Vec { + vec![4 * LIMBS] + } + + pub fn instances(&self) -> Vec> { + vec![self.instances.clone()] + } + + pub fn as_proof(&self) -> Value<&[u8]> { + self.as_proof.as_ref().map(Vec::as_slice) + } +} + +impl Circuit for AggregationCircuit { + type Config = AggregationConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + Self { + svk: self.svk, + snarks: self.snarks.iter().map(SnarkWitness::without_witnesses).collect(), + instances: Vec::new(), + as_proof: Value::unknown(), + } + } + + fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { + let path = std::env::var("VERIFY_CONFIG") + .unwrap_or_else(|_| "configs/verify_circuit.config".to_owned()); + let params: AggregationConfigParams = serde_json::from_reader( + File::open(path.as_str()).unwrap_or_else(|_| panic!("{path:?} does not exist")), + ) + .unwrap(); + + AggregationConfig::configure(meta, params) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), plonk::Error> { + config.range().load_lookup_table(&mut layouter)?; + + // assume using simple floor planner + let mut first_pass = halo2_base::SKIP_FIRST_PASS; + let mut assigned_instances = vec![]; + + layouter.assign_region( + || "", + |region| { + if first_pass { + first_pass = false; + return Ok(()); + } + #[cfg(feature = "display")] + let witness_time = start_timer!(|| "Witness Collection"); + let ctx = Context::new( + region, + ContextParams { + max_rows: config.gate().max_rows, + num_advice: vec![config.gate().num_advice], + fixed_columns: config.gate().constants.clone(), + }, + ); + + let ecc_chip = config.ecc_chip(); + let loader = Halo2Loader::new(ecc_chip, ctx); + let (_, KzgAccumulator { lhs, rhs }) = aggregate::>( + &self.svk, + &loader, + &self.snarks, + self.as_proof(), + ); + + let lhs = lhs.assigned(); + let rhs = rhs.assigned(); + + config.base_field_config.finalize(&mut loader.ctx_mut()); + #[cfg(feature = "display")] + println!("Total advice cells: {}", loader.ctx().total_advice); + #[cfg(feature = "display")] + println!("Advice columns used: {}", loader.ctx().advice_alloc[0][0].0 + 1); + + assigned_instances = lhs + .x + .truncation + .limbs + .iter() + .chain(lhs.y.truncation.limbs.iter()) + .chain(rhs.x.truncation.limbs.iter()) + .chain(rhs.y.truncation.limbs.iter()) + .map(|assigned| assigned.cell().clone()) + .collect_vec(); + #[cfg(feature = "display")] + end_timer!(witness_time); + Ok(()) + }, + )?; + + // Expose instances + // TODO: use less instances by following Scroll's strategy of keeping only last bit of y coordinate + for (i, cell) in assigned_instances.into_iter().enumerate() { + layouter.constrain_instance(cell, config.instance, i); + } + Ok(()) + } +} diff --git a/src/system/halo2/transcript/halo2.rs b/src/system/halo2/transcript/halo2.rs index aafdf925..2b564648 100644 --- a/src/system/halo2/transcript/halo2.rs +++ b/src/system/halo2/transcript/halo2.rs @@ -59,6 +59,24 @@ where let buf = Poseidon::new(loader, R_F, R_P); Self { loader: loader.clone(), stream, buf } } + + pub fn from_spec( + loader: &Rc>, + stream: Value, + spec: crate::poseidon::Spec, + ) -> Self { + let buf = Poseidon::from_spec(loader, spec); + Self { loader: loader.clone(), stream, buf } + } + + pub fn clear(&mut self) { + self.buf.clear(); + } + + pub fn new_stream(&mut self, stream: Value) { + self.buf.clear(); + self.stream = stream; + } } impl<'a, C, R, EccChip, const T: usize, const RATE: usize, const R_F: usize, const R_P: usize> @@ -149,6 +167,19 @@ impl Self { Self { loader: NativeLoader, stream, buf: Poseidon::new(&NativeLoader, R_F, R_P) } } + + pub fn from_spec(stream: S, spec: crate::poseidon::Spec) -> Self { + Self { loader: NativeLoader, stream, buf: Poseidon::from_spec(&NativeLoader, spec) } + } + + pub fn clear(&mut self) { + self.buf.clear(); + } + + pub fn new_stream(&mut self, stream: S) { + self.buf.clear(); + self.stream = stream; + } } impl diff --git a/src/util/hash/poseidon.rs b/src/util/hash/poseidon.rs index 5743cc32..fa7442f4 100644 --- a/src/util/hash/poseidon.rs +++ b/src/util/hash/poseidon.rs @@ -5,6 +5,7 @@ use crate::{ }; use std::{iter, marker::PhantomData, mem}; +#[derive(Clone)] struct State { inner: [L; T], _marker: PhantomData, @@ -83,21 +84,34 @@ impl, const T: usize, const RATE: usize> State { spec: Spec, + default_state: State, state: State, buf: Vec, } impl, const T: usize, const RATE: usize> Poseidon { pub fn new(loader: &L::Loader, r_f: usize, r_p: usize) -> Self { + let default_state = + State::new(poseidon::State::default().words().map(|state| loader.load_const(&state))); Self { spec: Spec::new(r_f, r_p), - state: State::new( - poseidon::State::default().words().map(|state| loader.load_const(&state)), - ), + state: default_state.clone(), + default_state, buf: Vec::new(), } } + pub fn from_spec(loader: &L::Loader, spec: Spec) -> Self { + let default_state = + State::new(poseidon::State::default().words().map(|state| loader.load_const(&state))); + Self { spec, state: default_state.clone(), default_state, buf: Vec::new() } + } + + pub fn clear(&mut self) { + self.state = self.default_state.clone(); + self.buf.clear(); + } + pub fn update(&mut self, elements: &[L]) { self.buf.extend_from_slice(elements); } From fc1b4f36b17fa2a41849dd2acce5aebacf1eedc3 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Thu, 1 Dec 2022 23:56:09 -0500 Subject: [PATCH 24/28] chore: move `aggregation` under `halo2` inside `sdk` --- benches/standard_plonk.rs | 9 ++++----- src/sdk.rs | 9 +++++---- src/sdk/evm.rs | 0 src/sdk/halo2.rs | 1 + src/sdk/{ => halo2}/aggregation.rs | 2 +- 5 files changed, 11 insertions(+), 10 deletions(-) create mode 100644 src/sdk/evm.rs create mode 100644 src/sdk/halo2.rs rename src/sdk/{ => halo2}/aggregation.rs (99%) diff --git a/benches/standard_plonk.rs b/benches/standard_plonk.rs index 6b89d945..edd3ad0a 100644 --- a/benches/standard_plonk.rs +++ b/benches/standard_plonk.rs @@ -1,13 +1,8 @@ use criterion::{criterion_group, criterion_main}; use criterion::{BenchmarkId, Criterion}; -use plonk_verifier::sdk::aggregation::AggregationCircuit; -use plonk_verifier::sdk::{ - self, gen_pk, gen_proof_shplonk, gen_snark_shplonk, PoseidonTranscript, POSEIDON_SPEC, -}; use pprof::criterion::{Output, PProfProfiler}; use ark_std::{end_timer, start_timer}; - use halo2_base::halo2_proofs; use halo2_proofs::halo2curves as halo2_curves; use halo2_proofs::{ @@ -32,6 +27,10 @@ use plonk_verifier::{ native::NativeLoader, }, pcs::kzg::{Gwc19, Kzg, KzgAs, LimbsEncoding}, + sdk::{ + self, gen_pk, gen_proof_shplonk, gen_snark_shplonk, halo2::aggregation::AggregationCircuit, + PoseidonTranscript, POSEIDON_SPEC, + }, system::halo2::{compile, transcript::evm::EvmTranscript, Config}, verifier::{self, PlonkVerifier}, }; diff --git a/src/sdk.rs b/src/sdk.rs index 520e9f8f..08104f29 100644 --- a/src/sdk.rs +++ b/src/sdk.rs @@ -6,8 +6,8 @@ use crate::{ loader::native::NativeLoader, pcs, poseidon::Spec, - system::halo2::{self, compile, Config}, - util::{hash, transcript::TranscriptWrite}, + system::halo2::{compile, Config}, + util::transcript::TranscriptWrite, verifier::PlonkProof, Protocol, }; @@ -47,7 +47,8 @@ use std::{ path::Path, }; -pub mod aggregation; +pub mod evm; +pub mod halo2; // Poseidon parameters const T: usize = 5; @@ -56,7 +57,7 @@ const R_F: usize = 8; const R_P: usize = 60; pub type PoseidonTranscript = - halo2::transcript::halo2::PoseidonTranscript; + crate::system::halo2::transcript::halo2::PoseidonTranscript; lazy_static! { pub static ref POSEIDON_SPEC: Spec = Spec::new(R_F, R_P); } diff --git a/src/sdk/evm.rs b/src/sdk/evm.rs new file mode 100644 index 00000000..e69de29b diff --git a/src/sdk/halo2.rs b/src/sdk/halo2.rs new file mode 100644 index 00000000..e5e5a4a5 --- /dev/null +++ b/src/sdk/halo2.rs @@ -0,0 +1 @@ +pub mod aggregation; diff --git a/src/sdk/aggregation.rs b/src/sdk/halo2/aggregation.rs similarity index 99% rename from src/sdk/aggregation.rs rename to src/sdk/halo2/aggregation.rs index bd76c116..e1766c4c 100644 --- a/src/sdk/aggregation.rs +++ b/src/sdk/halo2/aggregation.rs @@ -29,7 +29,7 @@ use std::{fs::File, rc::Rc}; const LIMBS: usize = 3; const BITS: usize = 88; -use super::{PoseidonTranscript, Snark, SnarkWitness, POSEIDON_SPEC}; +use crate::sdk::{PoseidonTranscript, Snark, SnarkWitness, POSEIDON_SPEC}; type Svk = KzgSuccinctVerifyingKey; type BaseFieldEccChip = halo2_ecc::ecc::BaseFieldEccChip; From c5f04860d992b0b18c60dc68712e038c26a76fc0 Mon Sep 17 00:00:00 2001 From: DoHoon Kim <59155248+DoHoonKim8@users.noreply.github.com> Date: Fri, 18 Nov 2022 00:50:24 +0900 Subject: [PATCH 25/28] Update `EvmLoader` to generate Yul code instead of bytecode (#15) * Update `EvmLoader` to generate Yul instead of bytecode * feat: simplify * feat: Add missing end_gas_metering impl Co-authored-by: Han Co-authored-by: Han --- examples/evm-verifier-with-accumulator.rs | 4 +- examples/evm-verifier.rs | 4 +- src/loader/evm.rs | 3 +- src/loader/evm/code.rs | 336 ++--------- src/loader/evm/loader.rs | 689 ++++++++++------------ src/loader/evm/test.rs | 19 +- src/loader/evm/util.rs | 40 +- src/pcs/kzg/decider.rs | 10 +- src/system/halo2/test/kzg/evm.rs | 8 +- src/system/halo2/transcript/evm.rs | 58 +- 10 files changed, 455 insertions(+), 716 deletions(-) diff --git a/examples/evm-verifier-with-accumulator.rs b/examples/evm-verifier-with-accumulator.rs index dc9e7164..2a8e2b5e 100644 --- a/examples/evm-verifier-with-accumulator.rs +++ b/examples/evm-verifier-with-accumulator.rs @@ -20,7 +20,7 @@ use halo2_proofs::{ use itertools::Itertools; use plonk_verifier::{ loader::{ - evm::{encode_calldata, EvmLoader, ExecutorBuilder}, + evm::{self, encode_calldata, EvmLoader, ExecutorBuilder}, native::NativeLoader, }, pcs::kzg::{Gwc19, Kzg, KzgAs, LimbsEncoding}, @@ -638,7 +638,7 @@ fn gen_aggregation_evm_verifier( let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); - loader.deployment_code() + evm::compile_yul(&loader.yul_code()) } fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { diff --git a/examples/evm-verifier.rs b/examples/evm-verifier.rs index d5ac0ebb..9f66ed86 100644 --- a/examples/evm-verifier.rs +++ b/examples/evm-verifier.rs @@ -21,7 +21,7 @@ use halo2_proofs::{ }; use itertools::Itertools; use plonk_verifier::{ - loader::evm::{encode_calldata, EvmLoader, ExecutorBuilder}, + loader::evm::{self, encode_calldata, EvmLoader, ExecutorBuilder}, pcs::kzg::{Gwc19, Kzg}, system::halo2::{compile, transcript::evm::EvmTranscript, Config}, verifier::{self, PlonkVerifier}, @@ -239,7 +239,7 @@ fn gen_evm_verifier( let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); - loader.deployment_code() + evm::compile_yul(&loader.yul_code()) } fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { diff --git a/src/loader/evm.rs b/src/loader/evm.rs index fa80b97e..263da0e2 100644 --- a/src/loader/evm.rs +++ b/src/loader/evm.rs @@ -7,7 +7,8 @@ mod test; pub use loader::{EcPoint, EvmLoader, Scalar}; pub use util::{ - encode_calldata, estimate_gas, fe_to_u256, modulus, u256_to_fe, ExecutorBuilder, MemoryChunk, + compile_yul, encode_calldata, estimate_gas, fe_to_u256, modulus, u256_to_fe, ExecutorBuilder, + MemoryChunk, }; pub use ethereum_types::U256; diff --git a/src/loader/evm/code.rs b/src/loader/evm/code.rs index 80dd5c71..840d1e67 100644 --- a/src/loader/evm/code.rs +++ b/src/loader/evm/code.rs @@ -1,7 +1,3 @@ -use crate::util::Itertools; -use ethereum_types::U256; -use std::{collections::HashMap, iter}; - pub enum Precompiled { BigModExp = 0x05, Bn254Add = 0x6, @@ -10,286 +6,70 @@ pub enum Precompiled { } #[derive(Clone, Debug)] -pub struct Code { - code: Vec, - constants: HashMap, - stack_len: usize, +pub struct YulCode { + // runtime code area + runtime: String, } -impl Code { - pub fn new(constants: impl IntoIterator) -> Self { - let mut code = Self { - code: Vec::new(), - constants: HashMap::new(), - stack_len: 0, - }; - let constants = constants.into_iter().collect_vec(); - for constant in constants.iter() { - code.push(*constant); - } - code.constants = HashMap::from_iter( - constants - .into_iter() - .enumerate() - .map(|(idx, value)| (value, idx)), - ); - code - } - - pub fn deployment(code: Vec) -> Vec { - let code_len = code.len(); - assert_ne!(code_len, 0); - - iter::empty() - .chain([ - PUSH1 + 1, - (code_len >> 8) as u8, - (code_len & 0xff) as u8, - PUSH1, - 14, - PUSH1, - 0, - CODECOPY, - ]) - .chain([ - PUSH1 + 1, - (code_len >> 8) as u8, - (code_len & 0xff) as u8, - PUSH1, - 0, - RETURN, - ]) - .chain(code) - .collect() - } - - pub fn stack_len(&self) -> usize { - self.stack_len - } - - pub fn len(&self) -> usize { - self.code.len() - } - - pub fn is_empty(&self) -> bool { - self.code.is_empty() - } - - pub fn push>(&mut self, value: T) -> &mut Self { - let value = value.into(); - match self.constants.get(&value) { - Some(idx) if (0..16).contains(&(self.stack_len - idx - 1)) => { - self.dup(self.stack_len - idx - 1); - } - _ => { - let mut bytes = vec![0; 32]; - value.to_big_endian(&mut bytes); - let bytes = bytes - .iter() - .position(|byte| *byte != 0) - .map_or(vec![0], |pos| bytes.drain(pos..).collect()); - self.code.push(PUSH1 - 1 + bytes.len() as u8); - self.code.extend(bytes); - self.stack_len += 1; - } +impl YulCode { + pub fn new() -> Self { + YulCode { + runtime: String::new(), } - self - } - - pub fn dup(&mut self, pos: usize) -> &mut Self { - assert!((0..16).contains(&pos)); - self.code.push(DUP1 + pos as u8); - self.stack_len += 1; - self } - pub fn swap(&mut self, pos: usize) -> &mut Self { - assert!((1..17).contains(&pos)); - self.code.push(SWAP1 - 1 + pos as u8); - self + pub fn code(&self, base_modulus: String, scalar_modulus: String) -> String { + format!( + " + object \"plonk_verifier\" {{ + code {{ + function allocate(size) -> ptr {{ + ptr := mload(0x40) + if eq(ptr, 0) {{ ptr := 0x60 }} + mstore(0x40, add(ptr, size)) + }} + let size := datasize(\"Runtime\") + let offset := allocate(size) + datacopy(offset, dataoffset(\"Runtime\"), size) + return(offset, size) + }} + object \"Runtime\" {{ + code {{ + let success:bool := true + let f_p := {base_modulus} + let f_q := {scalar_modulus} + function validate_ec_point(x, y) -> valid:bool {{ + {{ + let x_lt_p:bool := lt(x, {base_modulus}) + let y_lt_p:bool := lt(y, {base_modulus}) + valid := and(x_lt_p, y_lt_p) + }} + {{ + let x_is_zero:bool := eq(x, 0) + let y_is_zero:bool := eq(y, 0) + let x_or_y_is_zero:bool := or(x_is_zero, y_is_zero) + let x_and_y_is_not_zero:bool := not(x_or_y_is_zero) + valid := and(x_and_y_is_not_zero, valid) + }} + {{ + let y_square := mulmod(y, y, {base_modulus}) + let x_square := mulmod(x, x, {base_modulus}) + let x_cube := mulmod(x_square, x, {base_modulus}) + let x_cube_plus_3 := addmod(x_cube, 3, {base_modulus}) + let y_square_eq_x_cube_plus_3:bool := eq(x_cube_plus_3, y_square) + valid := and(y_square_eq_x_cube_plus_3, valid) + }} + }} + {} + }} + }} + }}", + self.runtime + ) } -} -impl From for Vec { - fn from(code: Code) -> Self { - code.code + pub fn runtime_append(&mut self, mut code: String) { + code.push('\n'); + self.runtime.push_str(&code); } } - -macro_rules! impl_opcodes { - ($($method:ident -> ($opcode:ident, $stack_len_diff:expr))*) => { - $( - #[allow(dead_code)] - impl Code { - pub fn $method(&mut self) -> &mut Self { - self.code.push($opcode); - self.stack_len = ((self.stack_len as isize) + $stack_len_diff) as usize; - self - } - } - )* - }; -} - -impl_opcodes!( - stop -> (STOP, 0) - add -> (ADD, -1) - mul -> (MUL, -1) - sub -> (SUB, -1) - div -> (DIV, -1) - sdiv -> (SDIV, -1) - r#mod -> (MOD, -1) - smod -> (SMOD, -1) - addmod -> (ADDMOD, -2) - mulmod -> (MULMOD, -2) - exp -> (EXP, -1) - signextend -> (SIGNEXTEND, -1) - lt -> (LT, -1) - gt -> (GT, -1) - slt -> (SLT, -1) - sgt -> (SGT, -1) - eq -> (EQ, -1) - iszero -> (ISZERO, 0) - and -> (AND, -1) - or -> (OR, -1) - xor -> (XOR, -1) - not -> (NOT, 0) - byte -> (BYTE, -1) - shl -> (SHL, -1) - shr -> (SHR, -1) - sar -> (SAR, -1) - keccak256 -> (SHA3, -1) - address -> (ADDRESS, 1) - balance -> (BALANCE, 0) - origin -> (ORIGIN, 1) - caller -> (CALLER, 1) - callvalue -> (CALLVALUE, 1) - calldataload -> (CALLDATALOAD, 0) - calldatasize -> (CALLDATASIZE, 1) - calldatacopy -> (CALLDATACOPY, -3) - codesize -> (CODESIZE, 1) - codecopy -> (CODECOPY, -3) - gasprice -> (GASPRICE, 1) - extcodesize -> (EXTCODESIZE, 0) - extcodecopy -> (EXTCODECOPY, -4) - returndatasize -> (RETURNDATASIZE, 1) - returndatacopy -> (RETURNDATACOPY, -3) - extcodehash -> (EXTCODEHASH, 0) - blockhash -> (BLOCKHASH, 0) - coinbase -> (COINBASE, 1) - timestamp -> (TIMESTAMP, 1) - number -> (NUMBER, 1) - difficulty -> (DIFFICULTY, 1) - gaslimit -> (GASLIMIT, 1) - chainid -> (CHAINID, 1) - selfbalance -> (SELFBALANCE, 1) - basefee -> (BASEFEE, 1) - pop -> (POP, -1) - mload -> (MLOAD, 0) - mstore -> (MSTORE, -2) - mstore8 -> (MSTORE8, -2) - sload -> (SLOAD, 0) - sstore -> (SSTORE, -2) - jump -> (JUMP, -1) - jumpi -> (JUMPI, -2) - pc -> (PC, 1) - msize -> (MSIZE, 1) - gas -> (GAS, 1) - jumpdest -> (JUMPDEST, 0) - log0 -> (LOG0, -2) - log1 -> (LOG1, -3) - log2 -> (LOG2, -4) - log3 -> (LOG3, -5) - log4 -> (LOG4, -6) - create -> (CREATE, -2) - call -> (CALL, -6) - callcode -> (CALLCODE, -6) - r#return -> (RETURN, -2) - delegatecall -> (DELEGATECALL, -5) - create2 -> (CREATE2, -3) - staticcall -> (STATICCALL, -5) - revert -> (REVERT, -2) - selfdestruct -> (SELFDESTRUCT, -1) -); - -const STOP: u8 = 0x00; -const ADD: u8 = 0x01; -const MUL: u8 = 0x02; -const SUB: u8 = 0x03; -const DIV: u8 = 0x04; -const SDIV: u8 = 0x05; -const MOD: u8 = 0x06; -const SMOD: u8 = 0x07; -const ADDMOD: u8 = 0x08; -const MULMOD: u8 = 0x09; -const EXP: u8 = 0x0A; -const SIGNEXTEND: u8 = 0x0B; -const LT: u8 = 0x10; -const GT: u8 = 0x11; -const SLT: u8 = 0x12; -const SGT: u8 = 0x13; -const EQ: u8 = 0x14; -const ISZERO: u8 = 0x15; -const AND: u8 = 0x16; -const OR: u8 = 0x17; -const XOR: u8 = 0x18; -const NOT: u8 = 0x19; -const BYTE: u8 = 0x1A; -const SHL: u8 = 0x1B; -const SHR: u8 = 0x1C; -const SAR: u8 = 0x1D; -const SHA3: u8 = 0x20; -const ADDRESS: u8 = 0x30; -const BALANCE: u8 = 0x31; -const ORIGIN: u8 = 0x32; -const CALLER: u8 = 0x33; -const CALLVALUE: u8 = 0x34; -const CALLDATALOAD: u8 = 0x35; -const CALLDATASIZE: u8 = 0x36; -const CALLDATACOPY: u8 = 0x37; -const CODESIZE: u8 = 0x38; -const CODECOPY: u8 = 0x39; -const GASPRICE: u8 = 0x3A; -const EXTCODESIZE: u8 = 0x3B; -const EXTCODECOPY: u8 = 0x3C; -const RETURNDATASIZE: u8 = 0x3D; -const RETURNDATACOPY: u8 = 0x3E; -const EXTCODEHASH: u8 = 0x3F; -const BLOCKHASH: u8 = 0x40; -const COINBASE: u8 = 0x41; -const TIMESTAMP: u8 = 0x42; -const NUMBER: u8 = 0x43; -const DIFFICULTY: u8 = 0x44; -const GASLIMIT: u8 = 0x45; -const CHAINID: u8 = 0x46; -const SELFBALANCE: u8 = 0x47; -const BASEFEE: u8 = 0x48; -const POP: u8 = 0x50; -const MLOAD: u8 = 0x51; -const MSTORE: u8 = 0x52; -const MSTORE8: u8 = 0x53; -const SLOAD: u8 = 0x54; -const SSTORE: u8 = 0x55; -const JUMP: u8 = 0x56; -const JUMPI: u8 = 0x57; -const PC: u8 = 0x58; -const MSIZE: u8 = 0x59; -const GAS: u8 = 0x5A; -const JUMPDEST: u8 = 0x5B; -const PUSH1: u8 = 0x60; -const DUP1: u8 = 0x80; -const SWAP1: u8 = 0x90; -const LOG0: u8 = 0xA0; -const LOG1: u8 = 0xA1; -const LOG2: u8 = 0xA2; -const LOG3: u8 = 0xA3; -const LOG4: u8 = 0xA4; -const CREATE: u8 = 0xF0; -const CALL: u8 = 0xF1; -const CALLCODE: u8 = 0xF2; -const RETURN: u8 = 0xF3; -const DELEGATECALL: u8 = 0xF4; -const CREATE2: u8 = 0xF5; -const STATICCALL: u8 = 0xFA; -const REVERT: u8 = 0xFD; -const SELFDESTRUCT: u8 = 0xFF; diff --git a/src/loader/evm/loader.rs b/src/loader/evm/loader.rs index d8bda683..db15c8d7 100644 --- a/src/loader/evm/loader.rs +++ b/src/loader/evm/loader.rs @@ -1,9 +1,11 @@ use crate::{ - loader::evm::{ - code::{Code, Precompiled}, - fe_to_u256, modulus, + loader::{ + evm::{ + code::{Precompiled, YulCode}, + fe_to_u256, modulus, u256_to_fe, + }, + EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader, }, - loader::{evm::u256_to_fe, EcPointLoader, LoadedEcPoint, LoadedScalar, Loader, ScalarLoader}, util::{ arithmetic::{CurveAffine, FieldOps, PrimeField}, Itertools, @@ -11,6 +13,7 @@ use crate::{ Error, }; use ethereum_types::{U256, U512}; +use hex; use std::{ cell::RefCell, collections::HashMap, @@ -50,24 +53,29 @@ impl Value { pub struct EvmLoader { base_modulus: U256, scalar_modulus: U256, - code: RefCell, + code: RefCell, ptr: RefCell, cache: RefCell>, #[cfg(test)] gas_metering_ids: RefCell>, } +fn hex_encode_u256(value: &U256) -> String { + let mut bytes = [0; 32]; + value.to_big_endian(&mut bytes); + format!("0x{}", hex::encode(bytes)) +} + impl EvmLoader { pub fn new() -> Rc where - Base: PrimeField, + Base: PrimeField, Scalar: PrimeField, { let base_modulus = modulus::(); let scalar_modulus = modulus::(); - let code = Code::new([1.into(), base_modulus, scalar_modulus - 1, scalar_modulus]) - .push(1) - .to_owned(); + let code = YulCode::new(); + Rc::new(Self { base_modulus, scalar_modulus, @@ -79,22 +87,16 @@ impl EvmLoader { }) } - pub fn deployment_code(self: &Rc) -> Vec { - Code::deployment(self.runtime_code()) - } - - pub fn runtime_code(self: &Rc) -> Vec { - let mut code = self.code.borrow().clone(); - let dst = code.len() + 9; - code.push(dst) - .jumpi() - .push(0) - .push(0) - .revert() - .jumpdest() - .stop() - .to_owned() - .into() + pub fn yul_code(self: &Rc) -> String { + let code = " + if not(success) { revert(0, 0) } + return(0, 0)" + .to_string(); + self.code.borrow_mut().runtime_append(code); + self.code.borrow().code( + hex_encode_u256(&self.base_modulus), + hex_encode_u256(&self.scalar_modulus), + ) } pub fn allocate(self: &Rc, size: usize) -> usize { @@ -103,121 +105,64 @@ impl EvmLoader { ptr } - pub(crate) fn scalar_modulus(&self) -> U256 { - self.scalar_modulus - } - pub(crate) fn ptr(&self) -> usize { *self.ptr.borrow() } - pub(crate) fn code_mut(&self) -> impl DerefMut + '_ { + pub(crate) fn code_mut(&self) -> impl DerefMut + '_ { self.code.borrow_mut() } - pub(crate) fn scalar(self: &Rc, value: Value) -> Scalar { - let value = if matches!( - value, - Value::Constant(_) | Value::Memory(_) | Value::Negated(_) - ) { - value - } else { - let identifier = value.identifier(); - let some_ptr = self.cache.borrow().get(&identifier).cloned(); - let ptr = if let Some(ptr) = some_ptr { - ptr - } else { - self.push(&Scalar { - loader: self.clone(), - value, - }); - let ptr = self.allocate(0x20); - self.code.borrow_mut().push(ptr).mstore(); - self.cache.borrow_mut().insert(identifier, ptr); - ptr - }; - Value::Memory(ptr) - }; - Scalar { - loader: self.clone(), - value, - } - } - - fn ec_point(self: &Rc, value: Value<(U256, U256)>) -> EcPoint { - EcPoint { - loader: self.clone(), - value, - } - } - - fn push(self: &Rc, scalar: &Scalar) { + fn push(self: &Rc, scalar: &Scalar) -> String { match scalar.value.clone() { Value::Constant(constant) => { - self.code.borrow_mut().push(constant); + format!("{constant}") } Value::Memory(ptr) => { - self.code.borrow_mut().push(ptr).mload(); + format!("mload({ptr:#x})") } Value::Negated(value) => { - self.push(&self.scalar(*value)); - self.code.borrow_mut().push(self.scalar_modulus).sub(); + let v = self.push(&self.scalar(*value)); + format!("sub(f_q, {v})") } Value::Sum(lhs, rhs) => { - self.code.borrow_mut().push(self.scalar_modulus); - self.push(&self.scalar(*lhs)); - self.push(&self.scalar(*rhs)); - self.code.borrow_mut().addmod(); + let lhs = self.push(&self.scalar(*lhs)); + let rhs = self.push(&self.scalar(*rhs)); + format!("addmod({lhs}, {rhs}, f_q)") } Value::Product(lhs, rhs) => { - self.code.borrow_mut().push(self.scalar_modulus); - self.push(&self.scalar(*lhs)); - self.push(&self.scalar(*rhs)); - self.code.borrow_mut().mulmod(); + let lhs = self.push(&self.scalar(*lhs)); + let rhs = self.push(&self.scalar(*rhs)); + format!("mulmod({lhs}, {rhs}, f_q)") } } } pub fn calldataload_scalar(self: &Rc, offset: usize) -> Scalar { let ptr = self.allocate(0x20); - self.code - .borrow_mut() - .push(self.scalar_modulus) - .push(offset) - .calldataload() - .r#mod() - .push(ptr) - .mstore(); + let code = format!("mstore({ptr:#x}, mod(calldataload({offset:#x}), f_q))"); + self.code.borrow_mut().runtime_append(code); self.scalar(Value::Memory(ptr)) } pub fn calldataload_ec_point(self: &Rc, offset: usize) -> EcPoint { - let ptr = self.allocate(0x40); - self.code - .borrow_mut() - // [..., success] - .push(offset) - // [..., success, x_cd_ptr] - .calldataload() - // [..., success, x] - .dup(0) - // [..., success, x, x] - .push(ptr) - // [..., success, x, x, x_ptr] - .mstore() - // [..., success, x] - .push(offset + 0x20) - // [..., success, x, y_cd_ptr] - .calldataload() - // [..., success, x, y] - .dup(0) - // [..., success, x, y, y] - .push(ptr + 0x20) - // [..., success, x, y, y, y_ptr] - .mstore(); - // [..., success, x, y] - self.validate_ec_point(); - self.ec_point(Value::Memory(ptr)) + let x_ptr = self.allocate(0x40); + let y_ptr = x_ptr + 0x20; + let x_cd_ptr = offset; + let y_cd_ptr = offset + 0x20; + let validate_code = self.validate_ec_point(); + let code = format!( + " + {{ + let x := calldataload({x_cd_ptr:#x}) + mstore({x_ptr:#x}, x) + let y := calldataload({y_cd_ptr:#x}) + mstore({y_ptr:#x}, y) + {validate_code} + }}" + ); + self.code.borrow_mut().runtime_append(code); + self.ec_point(Value::Memory(x_ptr)) } pub fn ec_point_from_limbs( @@ -226,124 +171,94 @@ impl EvmLoader { y_limbs: [&Scalar; LIMBS], ) -> EcPoint { let ptr = self.allocate(0x40); - for (ptr, limbs) in [(ptr, x_limbs), (ptr + 0x20, y_limbs)] { - for (idx, limb) in limbs.into_iter().enumerate() { - self.push(limb); - // [..., success, acc] - if idx > 0 { - self.code - .borrow_mut() - .push(idx * BITS) - // [..., success, acc, limb_i, shift] - .shl() - // [..., success, acc, limb_i << shift] - .add(); - // [..., success, acc] - } + let mut code = String::new(); + for (idx, limb) in x_limbs.iter().enumerate() { + let limb_i = self.push(limb); + let shift = idx * BITS; + if idx == 0 { + code.push_str(format!("let x := {limb_i}\n").as_str()); + } else { + code.push_str(format!("x := add(x, shl({shift}, {limb_i}))\n").as_str()); + } + } + let x_ptr = ptr; + code.push_str(format!("mstore({x_ptr}, x)\n").as_str()); + for (idx, limb) in y_limbs.iter().enumerate() { + let limb_i = self.push(limb); + let shift = idx * BITS; + if idx == 0 { + code.push_str(format!("let y := {limb_i}\n").as_str()); + } else { + code.push_str(format!("y := add(y, shl({shift}, {limb_i}))\n").as_str()); } - self.code - .borrow_mut() - // [..., success, coordinate] - .dup(0) - // [..., success, coordinate, coordinate] - .push(ptr) - // [..., success, coordinate, coordinate, ptr] - .mstore(); - // [..., success, coordinate] } - // [..., success, x, y] - self.validate_ec_point(); + let y_ptr = ptr + 0x20; + code.push_str(format!("mstore({y_ptr}, y)\n").as_str()); + let validate_code = self.validate_ec_point(); + let code = format!( + "{{ + {code} + {validate_code} + }}" + ); + self.code.borrow_mut().runtime_append(code); self.ec_point(Value::Memory(ptr)) } - fn validate_ec_point(self: &Rc) { - self.code - .borrow_mut() - // [..., success, x, y] - .push(self.base_modulus) - // [..., success, x, y, p] - .dup(2) - // [..., success, x, y, p, x] - .lt() - // [..., success, x, y, x_lt_p] - .push(self.base_modulus) - // [..., success, x, y, x_lt_p, p] - .dup(2) - // [..., success, x, y, x_lt_p, p, y] - .lt() - // [..., success, x, y, x_lt_p, y_lt_p] - .and() - // [..., success, x, y, valid] - .dup(2) - // [..., success, x, y, valid, x] - .iszero() - // [..., success, x, y, valid, x_is_zero] - .dup(2) - // [..., success, x, y, valid, x_is_zero, y] - .iszero() - // [..., success, x, y, valid, x_is_zero, y_is_zero] - .or() - // [..., success, x, y, valid, x_or_y_is_zero] - .not() - // [..., success, x, y, valid, x_and_y_is_not_zero] - .and() - // [..., success, x, y, valid] - .push(self.base_modulus) - // [..., success, x, y, valid, p] - .dup(2) - // [..., success, x, y, valid, p, y] - .dup(0) - // [..., success, x, y, valid, p, y, y] - .mulmod() - // [..., success, x, y, valid, y_square] - .push(self.base_modulus) - // [..., success, x, y, valid, y_square, p] - .push(3) - // [..., success, x, y, valid, y_square, p, 3] - .push(self.base_modulus) - // [..., success, x, y, valid, y_square, p, 3, p] - .dup(6) - // [..., success, x, y, valid, y_square, p, 3, p, x] - .push(self.base_modulus) - // [..., success, x, y, valid, y_square, p, 3, p, x, p] - .dup(1) - // [..., success, x, y, valid, y_square, p, 3, p, x, p, x] - .dup(0) - // [..., success, x, y, valid, y_square, p, 3, p, x, p, x, x] - .mulmod() - // [..., success, x, y, valid, y_square, p, 3, p, x, x_square] - .mulmod() - // [..., success, x, y, valid, y_square, p, 3, x_cube] - .addmod() - // [..., success, x, y, valid, y_square, x_cube_plus_3] - .eq() - // [..., success, x, y, valid, y_square_eq_x_cube_plus_3] - .and() - // [..., success, x, y, valid] - .swap(2) - // [..., success, valid, y, x] - .pop() - // [..., success, valid, y] - .pop() - // [..., success, valid] - .and(); + fn validate_ec_point(self: &Rc) -> String { + "success := and(validate_ec_point(x, y), success)".to_string() + } + + pub(crate) fn scalar(self: &Rc, value: Value) -> Scalar { + let value = if matches!( + value, + Value::Constant(_) | Value::Memory(_) | Value::Negated(_) + ) { + value + } else { + let identifier = value.identifier(); + let some_ptr = self.cache.borrow().get(&identifier).cloned(); + let ptr = if let Some(ptr) = some_ptr { + ptr + } else { + let v = self.push(&Scalar { + loader: self.clone(), + value, + }); + let ptr = self.allocate(0x20); + self.code + .borrow_mut() + .runtime_append(format!("mstore({ptr:#x}, {v})")); + self.cache.borrow_mut().insert(identifier, ptr); + ptr + }; + Value::Memory(ptr) + }; + Scalar { + loader: self.clone(), + value, + } + } + + fn ec_point(self: &Rc, value: Value<(U256, U256)>) -> EcPoint { + EcPoint { + loader: self.clone(), + value, + } } pub fn keccak256(self: &Rc, ptr: usize, len: usize) -> usize { let hash_ptr = self.allocate(0x20); - self.code - .borrow_mut() - .push(len) - .push(ptr) - .keccak256() - .push(hash_ptr) - .mstore(); + let code = format!("mstore({hash_ptr:#x}, keccak256({ptr:#x}, {len}))"); + self.code.borrow_mut().runtime_append(code); hash_ptr } pub fn copy_scalar(self: &Rc, scalar: &Scalar, ptr: usize) { - self.push(scalar); - self.code.borrow_mut().push(ptr).mstore(); + let scalar = self.push(scalar); + self.code + .borrow_mut() + .runtime_append(format!("mstore({ptr:#x}, {scalar})")); } pub fn dup_scalar(self: &Rc, scalar: &Scalar) -> Scalar { @@ -356,26 +271,26 @@ impl EvmLoader { let ptr = self.allocate(0x40); match value.value { Value::Constant((x, y)) => { - self.code - .borrow_mut() - .push(x) - .push(ptr) - .mstore() - .push(y) - .push(ptr + 0x20) - .mstore(); + let x_ptr = ptr; + let y_ptr = ptr + 0x20; + let x = hex_encode_u256(&x); + let y = hex_encode_u256(&y); + let code = format!( + "mstore({x_ptr:#x}, {x}) + mstore({y_ptr:#x}, {y})" + ); + self.code.borrow_mut().runtime_append(code); } Value::Memory(src_ptr) => { - self.code - .borrow_mut() - .push(src_ptr) - .mload() - .push(ptr) - .mstore() - .push(src_ptr + 0x20) - .mload() - .push(ptr + 0x20) - .mstore(); + let x_ptr = ptr; + let y_ptr = ptr + 0x20; + let src_x = src_ptr; + let src_y = src_ptr + 0x20; + let code = format!( + "mstore({x_ptr:#x}, mload({src_x:#x})) + mstore({y_ptr:#x}, mload({src_y:#x}))" + ); + self.code.borrow_mut().runtime_append(code); } Value::Negated(_) | Value::Sum(_, _) | Value::Product(_, _) => { unreachable!() @@ -391,16 +306,9 @@ impl EvmLoader { Precompiled::Bn254ScalarMul => (0x60, 0x40), Precompiled::Bn254Pairing => (0x180, 0x20), }; - self.code - .borrow_mut() - .push(rd_len) - .push(rd_ptr) - .push(cd_len) - .push(cd_ptr) - .push(precompile as usize) - .gas() - .staticcall() - .and(); + let a = precompile as usize; + let code = format!("success := and(eq(staticcall(gas(), {a:#x}, {cd_ptr:#x}, {cd_len:#x}, {rd_ptr:#x}, {rd_len:#x}), 1), success)"); + self.code.borrow_mut().runtime_append(code); } fn invert(self: &Rc, scalar: &Scalar) -> Scalar { @@ -441,38 +349,41 @@ impl EvmLoader { ) { let rd_ptr = self.dup_ec_point(lhs).ptr(); self.allocate(0x80); - self.code - .borrow_mut() - .push(g2.0) - .push(rd_ptr + 0x40) - .mstore() - .push(g2.1) - .push(rd_ptr + 0x60) - .mstore() - .push(g2.2) - .push(rd_ptr + 0x80) - .mstore() - .push(g2.3) - .push(rd_ptr + 0xa0) - .mstore(); + let g2_0 = hex_encode_u256(&g2.0); + let g2_0_ptr = rd_ptr + 0x40; + let g2_1 = hex_encode_u256(&g2.1); + let g2_1_ptr = rd_ptr + 0x60; + let g2_2 = hex_encode_u256(&g2.2); + let g2_2_ptr = rd_ptr + 0x80; + let g2_3 = hex_encode_u256(&g2.3); + let g2_3_ptr = rd_ptr + 0xa0; + let code = format!( + "mstore({g2_0_ptr:#x}, {g2_0}) + mstore({g2_1_ptr:#x}, {g2_1}) + mstore({g2_2_ptr:#x}, {g2_2}) + mstore({g2_3_ptr:#x}, {g2_3})" + ); + self.code.borrow_mut().runtime_append(code); self.dup_ec_point(rhs); self.allocate(0x80); - self.code - .borrow_mut() - .push(minus_s_g2.0) - .push(rd_ptr + 0x100) - .mstore() - .push(minus_s_g2.1) - .push(rd_ptr + 0x120) - .mstore() - .push(minus_s_g2.2) - .push(rd_ptr + 0x140) - .mstore() - .push(minus_s_g2.3) - .push(rd_ptr + 0x160) - .mstore(); + let minus_s_g2_0 = hex_encode_u256(&minus_s_g2.0); + let minus_s_g2_0_ptr = rd_ptr + 0x100; + let minus_s_g2_1 = hex_encode_u256(&minus_s_g2.1); + let minus_s_g2_1_ptr = rd_ptr + 0x120; + let minus_s_g2_2 = hex_encode_u256(&minus_s_g2.2); + let minus_s_g2_2_ptr = rd_ptr + 0x140; + let minus_s_g2_3 = hex_encode_u256(&minus_s_g2.3); + let minus_s_g2_3_ptr = rd_ptr + 0x160; + let code = format!( + "mstore({minus_s_g2_0_ptr:#x}, {minus_s_g2_0}) + mstore({minus_s_g2_1_ptr:#x}, {minus_s_g2_1}) + mstore({minus_s_g2_2_ptr:#x}, {minus_s_g2_2}) + mstore({minus_s_g2_3_ptr:#x}, {minus_s_g2_3})" + ); + self.code.borrow_mut().runtime_append(code); self.staticcall(Precompiled::Bn254Pairing, rd_ptr, rd_ptr); - self.code.borrow_mut().push(rd_ptr).mload().and(); + let code = format!("success := and(eq(mload({rd_ptr:#x}), 1), success)"); + self.code.borrow_mut().runtime_append(code); } fn add(self: &Rc, lhs: &Scalar, rhs: &Scalar) -> Scalar { @@ -525,21 +436,16 @@ impl EvmLoader { self.gas_metering_ids .borrow_mut() .push(identifier.to_string()); - self.code.borrow_mut().gas().swap(1); + let code = format!("let {identifier} := gas()"); + self.code.borrow_mut().runtime_append(code); } fn end_gas_metering(self: &Rc) { - self.code - .borrow_mut() - .swap(1) - .push(9) - .gas() - .swap(2) - .sub() - .sub() - .push(0) - .push(0) - .log1(); + let code = format!( + "log1(0, 0, sub({}, gas()))", + self.gas_metering_ids.borrow().last().unwrap() + ); + self.code.borrow_mut().runtime_append(code); } pub fn print_gas_metering(self: &Rc, costs: Vec) { @@ -745,7 +651,7 @@ impl PartialEq for Scalar { impl> LoadedScalar for Scalar { type Loader = Rc; - fn loader(&self) -> &Rc { + fn loader(&self) -> &Self::Loader { &self.loader } } @@ -753,7 +659,7 @@ impl> LoadedScalar for Scalar { impl EcPointLoader for Rc where C: CurveAffine, - C::ScalarExt: PrimeField, + C::Scalar: PrimeField, { type LoadedEcPoint = EcPoint; @@ -802,48 +708,39 @@ impl> ScalarLoader for Rc { let push_addend = |(coeff, value): &(F, &Scalar)| { assert_ne!(*coeff, F::zero()); match (*coeff == F::one(), &value.value) { - (true, _) => { - self.push(value); - } - (false, Value::Constant(value)) => { - self.push(&self.scalar(Value::Constant(fe_to_u256( - *coeff * u256_to_fe::(*value), - )))); - } + (true, _) => self.push(value), + (false, Value::Constant(value)) => self.push(&self.scalar(Value::Constant( + fe_to_u256(*coeff * u256_to_fe::(*value)), + ))), (false, _) => { - self.code.borrow_mut().push(self.scalar_modulus); - self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); - self.push(value); - self.code.borrow_mut().mulmod(); + let value = self.push(value); + let coeff = self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); + format!("mulmod({value}, {coeff}, f_q)") } } }; let mut values = values.iter(); - if constant == F::zero() { - push_addend(values.next().unwrap()); + let initial_value = if constant == F::zero() { + push_addend(values.next().unwrap()) } else { - self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))); - } - - let chunk_size = 16 - self.code.borrow().stack_len(); - for values in &values.chunks(chunk_size) { - let values = values.into_iter().collect_vec(); - - self.code.borrow_mut().push(self.scalar_modulus); - for _ in 1..chunk_size.min(values.len()) { - self.code.borrow_mut().dup(0); - } - self.code.borrow_mut().swap(chunk_size.min(values.len())); + self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))) + }; - for value in values { - push_addend(value); - self.code.borrow_mut().addmod(); - } + let mut code = format!("let result := {initial_value}\n"); + for value in values { + let v = push_addend(value); + let addend = format!("result := addmod({v}, result, f_q)\n"); + code.push_str(addend.as_str()); } let ptr = self.allocate(0x20); - self.code.borrow_mut().push(ptr).mstore(); + code.push_str(format!("mstore({ptr}, result)").as_str()); + self.code.borrow_mut().runtime_append(format!( + "{{ + {code} + }}" + )); self.scalar(Value::Memory(ptr)) } @@ -863,63 +760,64 @@ impl> ScalarLoader for Rc { (_, Value::Constant(lhs), Value::Constant(rhs)) => { self.push(&self.scalar(Value::Constant(fe_to_u256( *coeff * u256_to_fe::(*lhs) * u256_to_fe::(*rhs), - )))); + )))) } (_, value @ Value::Memory(_), Value::Constant(constant)) | (_, Value::Constant(constant), value @ Value::Memory(_)) => { - self.code.borrow_mut().push(self.scalar_modulus); - self.push(&self.scalar(Value::Constant(fe_to_u256( + let v1 = self.push(&self.scalar(value.clone())); + let v2 = self.push(&self.scalar(Value::Constant(fe_to_u256( *coeff * u256_to_fe::(*constant), )))); - self.push(&self.scalar(value.clone())); - self.code.borrow_mut().mulmod(); + format!("mulmod({v1}, {v2}, f_q)") } (true, _, _) => { - self.code.borrow_mut().push(self.scalar_modulus); - self.push(lhs); - self.push(rhs); - self.code.borrow_mut().mulmod(); + let rhs = self.push(rhs); + let lhs = self.push(lhs); + format!("mulmod({rhs}, {lhs}, f_q)") } (false, _, _) => { - self.code.borrow_mut().push(self.scalar_modulus).dup(0); - self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); - self.push(lhs); - self.code.borrow_mut().mulmod(); - self.push(rhs); - self.code.borrow_mut().mulmod(); + let rhs = self.push(rhs); + let lhs = self.push(lhs); + let value = self.push(&self.scalar(Value::Constant(fe_to_u256(*coeff)))); + format!("mulmod({rhs}, mulmod({lhs}, {value}, f_q), f_q)") } } }; let mut values = values.iter(); - if constant == F::zero() { - push_addend(values.next().unwrap()); + let initial_value = if constant == F::zero() { + push_addend(values.next().unwrap()) } else { - self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))); - } - - let chunk_size = 16 - self.code.borrow().stack_len(); - for values in &values.chunks(chunk_size) { - let values = values.into_iter().collect_vec(); - - self.code.borrow_mut().push(self.scalar_modulus); - for _ in 1..chunk_size.min(values.len()) { - self.code.borrow_mut().dup(0); - } - self.code.borrow_mut().swap(chunk_size.min(values.len())); + self.push(&self.scalar(Value::Constant(fe_to_u256(constant)))) + }; - for value in values { - push_addend(value); - self.code.borrow_mut().addmod(); - } + let mut code = format!("let result := {initial_value}\n"); + for value in values { + let v = push_addend(value); + let addend = format!("result := addmod({v}, result, f_q)\n"); + code.push_str(addend.as_str()); } let ptr = self.allocate(0x20); - self.code.borrow_mut().push(ptr).mstore(); + code.push_str(format!("mstore({ptr}, result)").as_str()); + self.code.borrow_mut().runtime_append(format!( + "{{ + {code} + }}" + )); self.scalar(Value::Memory(ptr)) } + // batch_invert algorithm + // n := values.len() - 1 + // input : values[0], ..., values[n] + // output : values[0]^{-1}, ..., values[n]^{-1} + // 1. products[i] <- values[0] * ... * values[i], i = 1, ..., n + // 2. inv <- (products[n])^{-1} + // 3. v_n <- values[n] + // 4. values[n] <- products[n - 1] * inv (values[n]^{-1}) + // 5. inv <- v_n * inv fn batch_invert<'a>(values: impl IntoIterator) { let values = values.into_iter().collect_vec(); let loader = &values.first().unwrap().loader; @@ -931,29 +829,35 @@ impl> ScalarLoader for Rc { ) .collect_vec(); - loader.code.borrow_mut().push(loader.scalar_modulus); - for _ in 2..values.len() { - loader.code.borrow_mut().dup(0); + let initial_value = loader.push(products.first().unwrap()); + let mut code = format!("let prod := {initial_value}\n"); + for (_, (value, product)) in values.iter().zip(products.iter()).skip(1).enumerate() { + let v = loader.push(value); + let ptr = product.ptr(); + code.push_str( + format!( + " + prod := mulmod({v}, prod, f_q) + mstore({ptr:#x}, prod) + " + ) + .as_str(), + ); } - - loader.push(products.first().unwrap()); - for (idx, (value, product)) in values.iter().zip(products.iter()).skip(1).enumerate() { - loader.push(value); - loader.code.borrow_mut().mulmod(); - if idx < values.len() - 2 { - loader.code.borrow_mut().dup(0); - } - loader.code.borrow_mut().push(product.ptr()).mstore(); - } - - let inv = loader.invert(products.last().unwrap()); - - loader.code.borrow_mut().push(loader.scalar_modulus); - for _ in 2..values.len() { - loader.code.borrow_mut().dup(0); - } - - loader.push(&inv); + loader.code.borrow_mut().runtime_append(format!( + "{{ + {code} + }}" + )); + + let inv = loader.push(&loader.invert(products.last().unwrap())); + + let mut code = format!( + " + let inv := {inv} + let v + " + ); for (value, product) in values.iter().rev().zip( products .iter() @@ -963,22 +867,29 @@ impl> ScalarLoader for Rc { .chain(iter::once(None)), ) { if let Some(product) = product { - loader.push(value); - loader - .code - .borrow_mut() - .dup(2) - .dup(2) - .push(product.ptr()) - .mload() - .mulmod() - .push(value.ptr()) - .mstore() - .mulmod(); + let val_ptr = value.ptr(); + let prod_ptr = product.ptr(); + let v = loader.push(value); + code.push_str( + format!( + " + v := {v} + mstore({val_ptr}, mulmod(mload({prod_ptr:#x}), inv, f_q)) + inv := mulmod(v, inv, f_q) + " + ) + .as_str(), + ); } else { - loader.code.borrow_mut().push(value.ptr()).mstore(); + let ptr = value.ptr(); + code.push_str(format!("mstore({ptr:#x}, inv)\n").as_str()); } } + loader.code.borrow_mut().runtime_append(format!( + "{{ + {code} + }}" + )); } } diff --git a/src/loader/evm/test.rs b/src/loader/evm/test.rs index f81c5f54..e6f3703e 100644 --- a/src/loader/evm/test.rs +++ b/src/loader/evm/test.rs @@ -3,7 +3,6 @@ use crate::{ util::Itertools, }; use ethereum_types::{Address, U256}; -use revm::{AccountInfo, Bytecode}; use std::env::var_os; mod tui; @@ -15,28 +14,26 @@ fn debug() -> bool { ) } -pub fn execute(code: Vec, calldata: Vec) -> (bool, u64, Vec) { +pub fn execute(deployment_code: Vec, calldata: Vec) -> (bool, u64, Vec) { assert!( - code.len() <= 0x6000, + deployment_code.len() <= 0x6000, "Contract size {} exceeds the limit 24576", - code.len() + deployment_code.len() ); let debug = debug(); let caller = Address::from_low_u64_be(0xfe); - let callee = Address::from_low_u64_be(0xff); let mut evm = ExecutorBuilder::default() .with_gas_limit(u64::MAX.into()) .set_debugger(debug) .build(); - evm.db_mut().insert_account_info( - callee, - AccountInfo::new(0.into(), 1, Bytecode::new_raw(code.into())), - ); - - let result = evm.call_raw(caller, callee, calldata.into(), 0.into()); + let contract = evm + .deploy(caller, deployment_code.into(), 0.into()) + .address + .unwrap(); + let result = evm.call_raw(caller, contract, calldata.into(), 0.into()); let costs = result .logs diff --git a/src/loader/evm/util.rs b/src/loader/evm/util.rs index 0a772513..a7df5209 100644 --- a/src/loader/evm/util.rs +++ b/src/loader/evm/util.rs @@ -3,7 +3,11 @@ use crate::{ util::{arithmetic::PrimeField, Itertools}, }; use ethereum_types::U256; -use std::iter; +use std::{ + io::Write, + iter, + process::{Command, Stdio}, +}; pub(crate) mod executor; @@ -94,3 +98,37 @@ pub fn estimate_gas(cost: Cost) -> usize { intrinsic_cost + calldata_cost + ec_operation_cost } + +pub fn compile_yul(code: &str) -> Vec { + let mut cmd = Command::new("solc") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .arg("--bin") + .arg("--yul") + .arg("-") + .spawn() + .unwrap(); + cmd.stdin + .take() + .unwrap() + .write_all(code.as_bytes()) + .unwrap(); + let output = cmd.wait_with_output().unwrap().stdout; + let binary = *split_by_ascii_whitespace(&output).last().unwrap(); + hex::decode(binary).unwrap() +} + +fn split_by_ascii_whitespace(bytes: &[u8]) -> Vec<&[u8]> { + let mut split = Vec::new(); + let mut start = None; + for (idx, byte) in bytes.iter().enumerate() { + if byte.is_ascii_whitespace() { + if let Some(start) = start.take() { + split.push(&bytes[start..idx]); + } + } else if start.is_none() { + start = Some(idx); + } + } + split +} diff --git a/src/pcs/kzg/decider.rs b/src/pcs/kzg/decider.rs index 0cb8b720..baabda6c 100644 --- a/src/pcs/kzg/decider.rs +++ b/src/pcs/kzg/decider.rs @@ -140,14 +140,8 @@ mod evm { let hash_ptr = loader.keccak256(lhs[0].ptr(), lhs.len() * 0x80); let challenge_ptr = loader.allocate(0x20); - loader - .code_mut() - .push(loader.scalar_modulus()) - .push(hash_ptr) - .mload() - .r#mod() - .push(challenge_ptr) - .mstore(); + let code = format!("mstore({challenge_ptr}, mod(mload({hash_ptr}), f_q))"); + loader.code_mut().runtime_append(code); let challenge = loader.scalar(Value::Memory(challenge_ptr)); let powers_of_challenge = LoadedScalar::::powers(&challenge, lhs.len()); diff --git a/src/system/halo2/test/kzg/evm.rs b/src/system/halo2/test/kzg/evm.rs index 91308410..28357415 100644 --- a/src/system/halo2/test/kzg/evm.rs +++ b/src/system/halo2/test/kzg/evm.rs @@ -25,7 +25,7 @@ macro_rules! halo2_kzg_evm_verify { use halo2_proofs::poly::commitment::ParamsProver; use std::rc::Rc; use $crate::{ - loader::evm::{encode_calldata, execute, EvmLoader}, + loader::evm::{compile_yul, encode_calldata, execute, EvmLoader}, system::halo2::{ test::kzg::{BITS, LIMBS}, transcript::evm::EvmTranscript, @@ -35,7 +35,7 @@ macro_rules! halo2_kzg_evm_verify { }; let loader = EvmLoader::new::(); - let runtime_code = { + let deployment_code = { let svk = $params.get_g()[0].into(); let dk = ($params.g2(), $params.s_g2()).into(); let protocol = $protocol.loaded(&loader); @@ -46,11 +46,11 @@ macro_rules! halo2_kzg_evm_verify { .unwrap(); <$plonk_verifier>::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); - loader.runtime_code() + compile_yul(&loader.yul_code()) }; let (accept, total_cost, costs) = - execute(runtime_code, encode_calldata($instances, &$proof)); + execute(deployment_code, encode_calldata($instances, &$proof)); loader.print_gas_metering(costs); println!("Total gas cost: {}", total_cost); diff --git a/src/system/halo2/transcript/evm.rs b/src/system/halo2/transcript/evm.rs index 569302b2..32d60b8f 100644 --- a/src/system/halo2/transcript/evm.rs +++ b/src/system/halo2/transcript/evm.rs @@ -1,4 +1,3 @@ -use crate::halo2_proofs; use crate::{ loader::{ evm::{loader::Value, u256_to_fe, EcPoint, EvmLoader, MemoryChunk, Scalar}, @@ -38,7 +37,12 @@ where assert_eq!(ptr, 0); let mut buf = MemoryChunk::new(ptr); buf.extend(0x20); - Self { loader: loader.clone(), stream: 0, buf, _marker: PhantomData } + Self { + loader: loader.clone(), + stream: 0, + buf, + _marker: PhantomData, + } } pub fn load_instances(&mut self, num_instance: Vec) -> Vec> { @@ -69,7 +73,9 @@ where fn squeeze_challenge(&mut self) -> Scalar { let len = if self.buf.len() == 0x20 { assert_eq!(self.loader.ptr(), self.buf.end()); - self.loader.code_mut().push(1).push(self.buf.end()).mstore8(); + let buf_end = self.buf.end(); + let code = format!("mstore8({buf_end}, 1)"); + self.loader.code_mut().runtime_append(code); 0x21 } else { self.buf.len() @@ -78,17 +84,14 @@ where let challenge_ptr = self.loader.allocate(0x20); let dup_hash_ptr = self.loader.allocate(0x20); - self.loader - .code_mut() - .push(hash_ptr) - .mload() - .push(self.loader.scalar_modulus()) - .dup(1) - .r#mod() - .push(challenge_ptr) - .mstore() - .push(dup_hash_ptr) - .mstore(); + let code = format!( + "{{ + let hash := mload({hash_ptr:#x}) + mstore({challenge_ptr:#x}, mod(hash, f_q)) + mstore({dup_hash_ptr:#x}, hash) + }}" + ); + self.loader.code_mut().runtime_append(code); self.buf.reset(dup_hash_ptr); self.buf.extend(0x20); @@ -146,7 +149,12 @@ where C: CurveAffine, { pub fn new(stream: S) -> Self { - Self { loader: NativeLoader, stream, buf: Vec::new(), _marker: PhantomData } + Self { + loader: NativeLoader, + stream, + buf: Vec::new(), + _marker: PhantomData, + } } } @@ -164,7 +172,11 @@ where .buf .iter() .cloned() - .chain(if self.buf.len() == 0x20 { Some(1) } else { None }) + .chain(if self.buf.len() == 0x20 { + Some(1) + } else { + None + }) .collect_vec(); let hash: [u8; 32] = Keccak256::digest(data).into(); self.buf = hash.to_vec(); @@ -181,7 +193,8 @@ where })?; [coordinates.x(), coordinates.y()].map(|coordinate| { - self.buf.extend(coordinate.to_repr().as_ref().iter().rev().cloned()); + self.buf + .extend(coordinate.to_repr().as_ref().iter().rev().cloned()); }); Ok(()) @@ -207,7 +220,10 @@ where .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; data.reverse(); let scalar = C::Scalar::from_repr_vartime(data).ok_or_else(|| { - Error::Transcript(io::ErrorKind::Other, "Invalid scalar encoding in proof".to_string()) + Error::Transcript( + io::ErrorKind::Other, + "Invalid scalar encoding in proof".to_string(), + ) })?; self.common_scalar(&scalar)?; Ok(scalar) @@ -223,8 +239,10 @@ where } let x = Option::from(::from_repr(x)); let y = Option::from(::from_repr(y)); - let ec_point = - x.zip(y).and_then(|(x, y)| Option::from(C::from_xy(x, y))).ok_or_else(|| { + let ec_point = x + .zip(y) + .and_then(|(x, y)| Option::from(C::from_xy(x, y))) + .ok_or_else(|| { Error::Transcript( io::ErrorKind::Other, "Invalid elliptic curve point encoding in proof".to_string(), From 888a9168e66374bb7c97f643ea4e452c2f6133c0 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Sat, 3 Dec 2022 20:44:46 -0500 Subject: [PATCH 26/28] feat: add `evm` module to `sdk` with evm proof and verifier functions * add bench for zkevm EVM circuit, currently uses my local repo because official repo does not compile... use scroll-dev-1115 circuit implementation where power of randomness is generated internally and not from public inputs (far too many public inputs for aggregation) --- Cargo.toml | 47 ++-- benches/standard_plonk.rs | 44 ++- benches/zkevm.rs | 140 ++++++++++ configs/bench_zkevm.config | 1 + examples/evm-verifier-with-accumulator.rs | 101 +++---- examples/recursion.rs | 4 +- src/loader/halo2/shim.rs | 2 +- src/pcs/kzg/accumulator.rs | 6 +- src/sdk.rs | 310 ++------------------- src/sdk/evm.rs | 191 +++++++++++++ src/sdk/halo2.rs | 323 ++++++++++++++++++++++ src/sdk/halo2/aggregation.rs | 69 +++-- src/system.rs | 1 - src/system/halo2.rs | 1 + src/system/halo2/test/circuit/standard.rs | 8 +- src/system/halo2/transcript/evm.rs | 35 +-- src/system/halo2/transcript/halo2.rs | 6 +- 17 files changed, 839 insertions(+), 450 deletions(-) create mode 100644 benches/zkevm.rs create mode 100644 configs/bench_zkevm.config diff --git a/Cargo.toml b/Cargo.toml index 3fecf432..d73c60f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,8 +34,10 @@ rlp = { version = "0.5", default-features = false, features = ["std"], optional halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", tag = "v0.2.1", package = "halo2_ecc", default-features = false, optional = true } # zkevm benchmarks -zkevm_circuit_benchmarks = {git = "https://github.com/privacy-scaling-explorations/zkevm-circuits.git", package = "circuit-benchmarks", features = ["benches"], optional = true } -zkevm_circuits = {git = "https://github.com/privacy-scaling-explorations/zkevm-circuits.git", package = "zkevm-circuits", optional = true } +zkevm-circuits = { path = "../zkevm-circuits/zkevm-circuits", features = ["test"], optional = true } +bus-mapping = { path = "../zkevm-circuits/bus-mapping", optional = true } +eth-types = { path = "../zkevm-circuits/eth-types", optional = true } +mock = { path = "../zkevm-circuits/mock", optional = true } [dev-dependencies] ark-std = { version = "0.3.0", features = ["print-trace"] } @@ -43,16 +45,12 @@ paste = "1.0.7" pprof = { version = "0.11", features = ["criterion", "flamegraph"] } criterion = "0.4" criterion-macro = "0.4" - -# system_halo2 -halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", tag = "v0.2.1", package = "halo2_ecc", default-features = false } - # loader_evm crossterm = { version = "0.25" } tui = { version = "0.19", default-features = false, features = ["crossterm"] } [features] -default = ["loader_evm", "loader_halo2", "system_halo2", "halo2-axiom", "halo2_ecc/jemallocator"] +default = ["loader_evm", "loader_halo2", "zkevm", "halo2-pse", "halo2_ecc?/jemallocator"] display = ["halo2_base/display", "halo2_ecc?/display", "ark-std"] # EXACTLY one of halo2-pse / halo2-axiom should always be turned on; not sure how to enforce this with Cargo halo2-pse = ["halo2_base/halo2-pse", "halo2_ecc?/halo2-pse"] @@ -62,39 +60,45 @@ parallel = ["dep:rayon"] loader_evm = ["dep:ethereum_types", "dep:sha3", "dep:revm", "dep:bytes", "dep:rlp"] loader_halo2 = ["halo2_ecc"] -system_halo2 = [] -zkevm = ["dep:zkevm_circuit_benchmarks", "dep:zkevm_circuits"] +zkevm = ["dep:zkevm-circuits", "dep:bus-mapping", "dep:mock", "dep:eth-types"] [[example]] name = "evm-verifier" -required-features = ["loader_evm", "system_halo2"] +required-features = ["loader_evm"] [[example]] name = "evm-verifier-with-accumulator" -required-features = ["loader_halo2", "system_halo2", "loader_evm"] +required-features = ["loader_halo2", "loader_evm"] [[example]] name = "recursion" -required-features = ["loader_halo2", "system_halo2"] +required-features = ["loader_halo2"] [[bench]] name = "standard_plonk" -required-features = ["loader_halo2", "system_halo2"] +required-features = ["loader_halo2"] +harness = false + +[[bench]] +name = "zkevm" +required-features = ["loader_halo2", "zkevm", "halo2-pse"] harness = false [profile.dev] opt-level = 3 -# Local "release" mode, more optimized than dev but much faster to compile than release +# Local "release" mode, more optimized than dev but faster to compile than release [profile.local] inherits = "dev" opt-level = 3 # Set this to 1 or 2 to get more useful backtraces -debug = 0 +debug = true +debug-assertions = false panic = 'unwind' # better recompile times incremental = true +lto = "thin" codegen-units = 16 [profile.release] @@ -104,9 +108,20 @@ lto = "fat" # codegen-units = 1 panic = "abort" +# For performance profiling +[profile.flamegraph] +inherits = "release" +debug = true + [patch."ssh://github.com/axiom-crypto/halo2-lib-working.git"] halo2_base = { path = "../halo2-lib-working/halo2_base" } halo2_ecc = { path = "../halo2-lib-working/halo2_ecc" } [patch."https://github.com/privacy-scaling-explorations/halo2curves.git"] -halo2curves = { path = "../halo2/arithmetic/curves" } \ No newline at end of file +halo2curves = { path = "../halo2/arithmetic/curves" } + +[patch."https://github.com/privacy-scaling-explorations/halo2.git"] +halo2_proofs = { git = "https://github.com/axiom-crypto/halo2.git", branch = "axiom/add-serialization" } + +[patch."https://github.com/scroll-tech/halo2.git"] +halo2_proofs = { git = "https://github.com/axiom-crypto/halo2.git", branch = "axiom/add-serialization" } \ No newline at end of file diff --git a/benches/standard_plonk.rs b/benches/standard_plonk.rs index edd3ad0a..189d6bdf 100644 --- a/benches/standard_plonk.rs +++ b/benches/standard_plonk.rs @@ -6,38 +6,22 @@ use ark_std::{end_timer, start_timer}; use halo2_base::halo2_proofs; use halo2_proofs::halo2curves as halo2_curves; use halo2_proofs::{ - dev::MockProver, - halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, - plonk::{create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey}, - poly::{ - commitment::{Params, ParamsProver}, - kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::{ProverGWC, VerifierGWC}, - strategy::AccumulatorStrategy, - }, - VerificationStrategy, - }, - transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer}, + halo2curves::bn256::Bn256, + poly::{commitment::Params, kzg::commitment::ParamsKZG}, }; -use itertools::Itertools; use plonk_verifier::{ - loader::{ - evm::{encode_calldata, EvmLoader, ExecutorBuilder}, - native::NativeLoader, - }, - pcs::kzg::{Gwc19, Kzg, KzgAs, LimbsEncoding}, + loader::native::NativeLoader, sdk::{ - self, gen_pk, gen_proof_shplonk, gen_snark_shplonk, halo2::aggregation::AggregationCircuit, - PoseidonTranscript, POSEIDON_SPEC, + self, gen_pk, + halo2::{ + aggregation::AggregationCircuit, gen_proof_shplonk, gen_snark_shplonk, + PoseidonTranscript, POSEIDON_SPEC, + }, }, - system::halo2::{compile, transcript::evm::EvmTranscript, Config}, - verifier::{self, PlonkVerifier}, }; use rand::rngs::OsRng; use rand::SeedableRng; use rand_chacha::ChaCha20Rng; -use std::{io::Cursor, rc::Rc}; mod application { use super::halo2_curves::bn256::Fr; @@ -213,7 +197,7 @@ fn gen_application_snark( let circuit = application::StandardPlonk::rand(OsRng); let pk = gen_pk(params, &circuit, None); - gen_snark_shplonk(params, &pk, circuit, transcript, None) + gen_snark_shplonk(params, &pk, circuit, transcript, &mut OsRng, None) } fn bench(c: &mut Criterion) { @@ -245,7 +229,15 @@ fn bench(c: &mut Criterion) { |b, &(params, pk, agg_circuit)| { b.iter(|| { let instances = agg_circuit.instances(); - gen_proof_shplonk(params, pk, agg_circuit.clone(), instances, &mut transcript, None) + gen_proof_shplonk( + params, + pk, + agg_circuit.clone(), + instances, + &mut transcript, + &mut rng, + None, + ) }) }, ); diff --git a/benches/zkevm.rs b/benches/zkevm.rs new file mode 100644 index 00000000..8fb4ab26 --- /dev/null +++ b/benches/zkevm.rs @@ -0,0 +1,140 @@ +use std::env::{set_var, var}; +use std::path::Path; + +use ark_std::{end_timer, start_timer}; +use halo2_base::halo2_proofs; +use halo2_base::utils::fs::gen_srs; +use halo2_proofs::halo2curves::bn256::Fr; +use plonk_verifier::sdk::halo2::aggregation::load_verify_circuit_degree; +use plonk_verifier::{ + loader::native::NativeLoader, + sdk::{ + self, + evm::{ + evm_verify, gen_evm_proof_gwc, gen_evm_proof_shplonk, gen_evm_verifier_gwc, + gen_evm_verifier_shplonk, + }, + gen_pk, + halo2::{ + aggregation::AggregationCircuit, gen_proof_gwc, gen_proof_shplonk, gen_snark_shplonk, + PoseidonTranscript, POSEIDON_SPEC, + }, + }, +}; +use rand::SeedableRng; +use rand_chacha::ChaCha20Rng; + +use criterion::{criterion_group, criterion_main}; +use criterion::{BenchmarkId, Criterion}; +use pprof::criterion::{Output, PProfProfiler}; + +pub mod zkevm { + use super::Fr; + use bus_mapping::{circuit_input_builder::CircuitsParams, mock::BlockData}; + use eth_types::geth_types::GethData; + use mock::TestContext; + use zkevm_circuits::evm_circuit::{witness::block_convert, EvmCircuit}; + + pub fn test_circuit() -> EvmCircuit { + let empty_data: GethData = + TestContext::<0, 0>::new(None, |_| {}, |_, _| {}, |b, _| b).unwrap().into(); + + let mut builder = BlockData::new_from_geth_data_with_params( + empty_data.clone(), + CircuitsParams::default(), + ) + .new_circuit_input_builder(); + + builder.handle_block(&empty_data.eth_block, &empty_data.geth_traces).unwrap(); + + let block = block_convert(&builder.block, &builder.code_db).unwrap(); + + EvmCircuit::::new(block) + } +} + +fn bench(c: &mut Criterion) { + let mut rng = ChaCha20Rng::from_entropy(); + let mut transcript = + PoseidonTranscript::::from_spec(vec![], POSEIDON_SPEC.clone()); + + // === create zkevm evm circuit snark === + let k: u32 = var("DEGREE") + .unwrap_or_else(|_| { + set_var("DEGREE", "18"); + "18".to_owned() + }) + .parse() + .unwrap(); + let circuit = zkevm::test_circuit(); + let params_app = gen_srs(k); + let pk = gen_pk(¶ms_app, &circuit, Some(Path::new("data/zkevm_evm.pkey"))); + let snark = gen_snark_shplonk( + ¶ms_app, + &pk, + circuit, + &mut transcript, + &mut rng, + Some((Path::new("data/zkevm_evm.in"), Path::new("data/zkevm_evm.pf"))), + ); + let snarks = [snark]; + // === finished zkevm evm circuit === + + // === now to do aggregation === + set_var("VERIFY_CONFIG", "./configs/bench_zkevm.config"); + let k = load_verify_circuit_degree(); + let params = gen_srs(k); + + let start1 = start_timer!(|| "Create aggregation circuit"); + let agg_circuit = AggregationCircuit::new(¶ms, snarks, &mut transcript, &mut rng); + end_timer!(start1); + + let pk = gen_pk(¶ms, &agg_circuit, None); + + let mut group = c.benchmark_group("plonk-prover"); + group.sample_size(10); + group.bench_with_input( + BenchmarkId::new("zkevm-evm-agg", k), + &(¶ms, &pk, &agg_circuit), + |b, &(params, pk, agg_circuit)| { + b.iter(|| { + let instances = agg_circuit.instances(); + gen_proof_shplonk( + params, + pk, + agg_circuit.clone(), + instances, + &mut transcript, + &mut rng, + None, + ); + }) + }, + ); + group.finish(); + + #[cfg(feature = "loader_evm")] + { + let deployment_code = + gen_evm_verifier_shplonk::(¶ms, pk.get_vk(), None); + + let start2 = start_timer!(|| "Create EVM proof"); + let proof = gen_evm_proof_shplonk( + ¶ms, + &pk, + agg_circuit.clone(), + agg_circuit.instances(), + &mut rng, + ); + end_timer!(start2); + + evm_verify(deployment_code, agg_circuit.instances(), proof); + } +} + +criterion_group! { + name = benches; + config = Criterion::default().with_profiler(PProfProfiler::new(10, Output::Flamegraph(None))); + targets = bench +} +criterion_main!(benches); diff --git a/configs/bench_zkevm.config b/configs/bench_zkevm.config new file mode 100644 index 00000000..1cda14ab --- /dev/null +++ b/configs/bench_zkevm.config @@ -0,0 +1 @@ +{"strategy":"Simple","degree":23,"num_advice":5,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":22,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/examples/evm-verifier-with-accumulator.rs b/examples/evm-verifier-with-accumulator.rs index 2a8e2b5e..84ed5003 100644 --- a/examples/evm-verifier-with-accumulator.rs +++ b/examples/evm-verifier-with-accumulator.rs @@ -1,6 +1,7 @@ use ark_std::{end_timer, start_timer}; use ethereum_types::Address; use halo2_base::halo2_proofs; +use halo2_base::halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK}; use halo2_proofs::halo2curves as halo2_curves; use halo2_proofs::{ dev::MockProver, @@ -18,22 +19,29 @@ use halo2_proofs::{ transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer}, }; use itertools::Itertools; +use plonk_verifier::pcs::kzg::Bdfg21; +use plonk_verifier::sdk::evm::{ + evm_verify, gen_evm_proof_gwc, gen_evm_proof_shplonk, gen_evm_verifier_gwc, + gen_evm_verifier_shplonk, +}; use plonk_verifier::{ loader::{ evm::{self, encode_calldata, EvmLoader, ExecutorBuilder}, native::NativeLoader, }, pcs::kzg::{Gwc19, Kzg, KzgAs, LimbsEncoding}, + sdk::CircuitExt, system::halo2::{compile, transcript::evm::EvmTranscript, Config}, verifier::{self, PlonkVerifier}, }; use rand::rngs::OsRng; +use std::path::Path; use std::{io::Cursor, rc::Rc}; const LIMBS: usize = 3; const BITS: usize = 88; -type Pcs = Kzg; +type Pcs = Kzg; type As = KzgAs; type Plonk = verifier::Plonk>; @@ -205,7 +213,7 @@ mod aggregation { use super::halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; use super::halo2_proofs::{ circuit::{Cell, Layouter, SimpleFloorPlanner, Value}, - plonk::{self, Circuit, Column, ConstraintSystem, Instance}, + plonk::{self, Circuit, Column, ConstraintSystem, Instance, Selector}, poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, }; use super::{As, Plonk, BITS, LIMBS}; @@ -223,6 +231,7 @@ mod aggregation { }; */ use itertools::Itertools; + use plonk_verifier::sdk::CircuitExt; use plonk_verifier::{ loader::{self, native::NativeLoader}, pcs::{ @@ -436,23 +445,36 @@ mod aggregation { } } - pub fn accumulator_indices() -> Vec<(usize, usize)> { - (0..4 * LIMBS).map(|idx| (0, idx)).collect() + pub fn as_proof(&self) -> Value<&[u8]> { + self.as_proof.as_ref().map(Vec::as_slice) } + } - pub fn num_instance() -> Vec { + impl CircuitExt for AggregationCircuit { + fn num_instance() -> Vec { + // [..lhs, ..rhs] vec![4 * LIMBS] } - pub fn instances(&self) -> Vec> { + fn instances(&self) -> Vec> { vec![self.instances.clone()] } - pub fn as_proof(&self) -> Value<&[u8]> { - self.as_proof.as_ref().map(Vec::as_slice) + fn accumulator_indices() -> Option> { + Some((0..4 * LIMBS).map(|idx| (0, idx)).collect()) } - } + fn selectors(config: &Self::Config) -> Vec { + config + .base_field_config + .range + .gate + .basic_gates + .iter() + .map(|gate| gate.q_enable) + .collect() + } + } impl Circuit for AggregationCircuit { type Config = AggregationConfig; type FloorPlanner = SimpleFloorPlanner; @@ -513,6 +535,10 @@ mod aggregation { let rhs = rhs.assigned(); config.base_field_config.finalize(&mut loader.ctx_mut()); + #[cfg(feature = "display")] + println!("Total advice cells: {}", loader.ctx().total_advice); + #[cfg(feature = "display")] + println!("Advice columns used: {}", loader.ctx().advice_alloc[0][0].0 + 1); let instances: Vec<_> = lhs .x @@ -564,7 +590,7 @@ fn gen_proof< let proof_time = start_timer!(|| "Create proof"); let proof = { let mut transcript = TW::init(Vec::new()); - create_proof::, ProverGWC<_>, _, _, TW, _>( + create_proof::, ProverSHPLONK<_>, _, _, TW, _>( params, pk, &[circuit], @@ -579,8 +605,8 @@ fn gen_proof< let accept = { let mut transcript = TR::init(Cursor::new(proof.clone())); - VerificationStrategy::<_, VerifierGWC<_>>::finalize( - verify_proof::<_, VerifierGWC<_>, _, TR, _>( + VerificationStrategy::<_, VerifierSHPLONK<_>>::finalize( + verify_proof::<_, VerifierSHPLONK<_>, _, TR, _>( params.verifier_params(), pk.get_vk(), AccumulatorStrategy::new(params.verifier_params()), @@ -614,49 +640,6 @@ fn gen_application_snark(params: &ParamsKZG) -> aggregation::Snark { aggregation::Snark::new(protocol, circuit.instances(), proof) } -fn gen_aggregation_evm_verifier( - params: &ParamsKZG, - vk: &VerifyingKey, - num_instance: Vec, - accumulator_indices: Vec<(usize, usize)>, -) -> Vec { - let svk = params.get_g()[0].into(); - let dk = (params.g2(), params.s_g2()).into(); - let protocol = compile( - params, - vk, - Config::kzg() - .with_num_instance(num_instance.clone()) - .with_accumulator_indices(Some(accumulator_indices)), - ); - - let loader = EvmLoader::new::(); - let protocol = protocol.loaded(&loader); - let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); - - let instances = transcript.load_instances(num_instance); - let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); - Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); - - evm::compile_yul(&loader.yul_code()) -} - -fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { - let calldata = encode_calldata(&instances, &proof); - let success = { - let mut evm = ExecutorBuilder::default().with_gas_limit(u64::MAX.into()).build(); - - let caller = Address::from_low_u64_be(0xfe); - let verifier = evm.deploy(caller, deployment_code.into(), 0.into()).address.unwrap(); - let result = evm.call_raw(caller, verifier, calldata.into(), 0.into()); - - dbg!(result.gas_used); - - !result.reverted - }; - assert!(success); -} - fn main() { std::env::set_var("VERIFY_CONFIG", "./configs/example_evm_accumulator.config"); let params = halo2_base::utils::fs::gen_srs(21); @@ -669,18 +652,18 @@ fn main() { let snarks = [(); 3].map(|_| gen_application_snark(¶ms_app)); let agg_circuit = aggregation::AggregationCircuit::new(¶ms, snarks); let pk = gen_pk(¶ms, &agg_circuit); - let deployment_code = gen_aggregation_evm_verifier( + let deployment_code = gen_evm_verifier_shplonk::( ¶ms, pk.get_vk(), - aggregation::AggregationCircuit::num_instance(), - aggregation::AggregationCircuit::accumulator_indices(), + Some(Path::new("evm_verifier.yul")), ); - let proof = gen_proof::<_, _, EvmTranscript, EvmTranscript>( + let proof = gen_evm_proof_shplonk( ¶ms, &pk, agg_circuit.clone(), agg_circuit.instances(), + &mut OsRng, ); evm_verify(deployment_code, agg_circuit.instances(), proof); } diff --git a/examples/recursion.rs b/examples/recursion.rs index 487163f1..86c0d415 100644 --- a/examples/recursion.rs +++ b/examples/recursion.rs @@ -5,7 +5,7 @@ use common::*; use halo2_base::halo2_proofs; use halo2_base::utils::fs::gen_srs; use halo2_proofs::{ - circuit::{AssignedCell, Layouter, SimpleFloorPlanner, Value}, + circuit::{Layouter, SimpleFloorPlanner, Value}, dev::MockProver, halo2curves::{ bn256::{Bn256, Fq, Fr, G1Affine}, @@ -773,7 +773,9 @@ mod recursion { // IMPORTANT: config.base_field_config.finalize(&mut ctx); + #[cfg(feature = "display")] dbg!(ctx.total_advice); + #[cfg(feature = "display")] println!("Advice columns used: {}", ctx.advice_alloc[0][0].0 + 1); assigned_instances.extend( diff --git a/src/loader/halo2/shim.rs b/src/loader/halo2/shim.rs index 87bca461..588e9482 100644 --- a/src/loader/halo2/shim.rs +++ b/src/loader/halo2/shim.rs @@ -133,6 +133,7 @@ pub trait EccInstructions<'a, C: CurveAffine>: Clone + Debug { mod halo2_lib { use crate::halo2_proofs::{ circuit::{Cell, Value}, + halo2curves::CurveAffineExt, plonk::Error, }; use crate::{ @@ -142,7 +143,6 @@ mod halo2_lib { use halo2_base::{ self, gates::{flex_gate::FlexGateConfig, GateInstructions, RangeInstructions}, - halo2_proofs::halo2curves::CurveAffineExt, utils::PrimeField, AssignedValue, QuantumCell::{Constant, Existing, Witness}, diff --git a/src/pcs/kzg/accumulator.rs b/src/pcs/kzg/accumulator.rs index 03fa21f9..efc28cd8 100644 --- a/src/pcs/kzg/accumulator.rs +++ b/src/pcs/kzg/accumulator.rs @@ -214,15 +214,15 @@ mod halo2 { mod halo2_lib { use super::*; - use halo2_base::{halo2_proofs::halo2curves::CurveAffineExt, utils::BigPrimeField}; + use halo2_base::{halo2_proofs::halo2curves::CurveAffineExt, utils::PrimeField}; use halo2_ecc::ecc::BaseFieldEccChip; impl<'a, C, const LIMBS: usize, const BITS: usize> LimbsEncodingInstructions<'a, C, LIMBS, BITS> for BaseFieldEccChip where C: CurveAffineExt, - C::ScalarExt: BigPrimeField, - C::Base: BigPrimeField, + C::ScalarExt: PrimeField, + C::Base: PrimeField, { fn assign_ec_point_from_limbs( &self, diff --git a/src/sdk.rs b/src/sdk.rs index 08104f29..a2585f79 100644 --- a/src/sdk.rs +++ b/src/sdk.rs @@ -1,66 +1,34 @@ -#![allow(clippy::type_complexity)] -use crate::cost::CostEstimation; +#![allow(clippy::let_and_return)] use crate::halo2_proofs; -use crate::pcs::MultiOpenScheme; -use crate::{ - loader::native::NativeLoader, - pcs, - poseidon::Spec, - system::halo2::{compile, Config}, - util::transcript::TranscriptWrite, - verifier::PlonkProof, - Protocol, -}; +use crate::{pcs::kzg::LimbsEncoding, verifier, Protocol}; #[cfg(feature = "display")] use ark_std::{end_timer, start_timer}; use halo2_proofs::{ - circuit::{Layouter, Value}, - dev::MockProver, + circuit::Value, halo2curves::{ bn256::{Bn256, Fr, G1Affine}, group::ff::Field, }, - plonk::{ - create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ConstraintSystem, Error, - ProvingKey, Selector, VerifyingKey, - }, - poly::{ - commitment::{Params, ParamsProver, Prover, Verifier}, - kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - msm::DualMSM, - multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}, - strategy::{AccumulatorStrategy, GuardKZG}, - }, - VerificationStrategy, - }, + plonk::{keygen_pk, keygen_vk, Circuit, ProvingKey, Selector}, + poly::kzg::commitment::ParamsKZG, }; use itertools::Itertools; -use lazy_static::lazy_static; -use rand::SeedableRng; -use rand_chacha::ChaCha20Rng; use std::{ fs::{self, File}, io::{BufReader, BufWriter}, - iter, - marker::PhantomData, path::Path, }; +#[cfg(feature = "loader_evm")] pub mod evm; +#[cfg(feature = "loader_halo2")] pub mod halo2; -// Poseidon parameters -const T: usize = 5; -const RATE: usize = 4; -const R_F: usize = 8; -const R_P: usize = 60; +const LIMBS: usize = 3; +const BITS: usize = 88; -pub type PoseidonTranscript = - crate::system::halo2::transcript::halo2::PoseidonTranscript; -lazy_static! { - pub static ref POSEIDON_SPEC: Spec = Spec::new(R_F, R_P); -} +/// PCS be either `Kzg` or `Kzg` +pub type Plonk = verifier::Plonk>; pub struct Snark { pub protocol: Protocol, @@ -186,178 +154,6 @@ pub fn gen_pk>( } } -/// Generates a native proof using either SHPLONK or GWC proving method. Uses Poseidon for Fiat-Shamir. -/// -/// Caches the instances and proof if `path` is specified. -pub fn gen_proof<'params, C, P, V>( - params: &'params ParamsKZG, - pk: &'params ProvingKey, - circuit: C, - instances: Vec>, - transcript: &mut PoseidonTranscript>, - path: Option<(&Path, &Path)>, -) -> Vec -where - C: Circuit, - P: Prover<'params, KZGCommitmentScheme>, - V: Verifier< - 'params, - KZGCommitmentScheme, - Guard = GuardKZG<'params, Bn256>, - MSMAccumulator = DualMSM<'params, Bn256>, - >, -{ - #[cfg(debug_assertions)] - { - MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); - } - - let mut proof: Option> = None; - - if let Some((instance_path, proof_path)) = path { - let cached_instances = read_instances(instance_path); - if matches!(cached_instances, Ok(tmp) if tmp == instances) && proof_path.exists() { - #[cfg(feature = "display")] - let read_time = start_timer!(|| format!("Reading proof from {proof_path:?}")); - - proof = Some(fs::read(proof_path).unwrap()); - - #[cfg(feature = "display")] - end_timer!(read_time); - } - } - - let instances = instances.iter().map(Vec::as_slice).collect_vec(); - - let proof = proof.unwrap_or_else(|| { - #[cfg(feature = "display")] - let proof_time = start_timer!(|| "Create proof"); - - transcript.clear(); - create_proof::<_, P, _, _, _, _>( - params, - pk, - &[circuit], - &[&instances], - &mut ChaCha20Rng::from_entropy(), - transcript, - ) - .unwrap(); - let proof = transcript.stream_mut().split_off(0); - - #[cfg(feature = "display")] - end_timer!(proof_time); - - if let Some((instance_path, proof_path)) = path { - write_instances(&instances, instance_path); - fs::write(proof_path, &proof).unwrap(); - } - proof - }); - - debug_assert!({ - let mut transcript = PoseidonTranscript::::new(proof.as_slice()); - VerificationStrategy::<_, V>::finalize( - verify_proof::<_, V, _, _, _>( - params.verifier_params(), - pk.get_vk(), - AccumulatorStrategy::new(params.verifier_params()), - &[instances.as_slice()], - &mut transcript, - ) - .unwrap(), - ) - }); - - proof -} - -pub fn gen_proof_gwc>( - params: &ParamsKZG, - pk: &ProvingKey, - circuit: C, - instances: Vec>, - transcript: &mut PoseidonTranscript>, - path: Option<(&Path, &Path)>, -) -> Vec { - gen_proof::, VerifierGWC<_>>(params, pk, circuit, instances, transcript, path) -} - -pub fn gen_proof_shplonk>( - params: &ParamsKZG, - pk: &ProvingKey, - circuit: C, - instances: Vec>, - transcript: &mut PoseidonTranscript>, - path: Option<(&Path, &Path)>, -) -> Vec { - gen_proof::, VerifierSHPLONK<_>>( - params, pk, circuit, instances, transcript, path, - ) -} - -pub fn gen_snark<'params, ConcreteCircuit, P, V>( - params: &'params ParamsKZG, - pk: &'params ProvingKey, - circuit: ConcreteCircuit, - transcript: &mut PoseidonTranscript>, - path: Option<(&Path, &Path)>, -) -> Snark -where - ConcreteCircuit: CircuitExt, - P: Prover<'params, KZGCommitmentScheme>, - V: Verifier< - 'params, - KZGCommitmentScheme, - Guard = GuardKZG<'params, Bn256>, - MSMAccumulator = DualMSM<'params, Bn256>, - >, -{ - let protocol = compile( - params, - pk.get_vk(), - Config::kzg() - .with_num_instance(ConcreteCircuit::num_instance()) - .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), - ); - - let instances = circuit.instances(); - let proof = gen_proof::( - params, - pk, - circuit, - instances.clone(), - transcript, - path, - ); - - Snark::new(protocol, instances, proof) -} - -pub fn gen_snark_gwc>( - params: &ParamsKZG, - pk: &ProvingKey, - circuit: ConcreteCircuit, - transcript: &mut PoseidonTranscript>, - path: Option<(&Path, &Path)>, -) -> Snark { - gen_snark::, VerifierGWC<_>>( - params, pk, circuit, transcript, path, - ) -} - -pub fn gen_snark_shplonk>( - params: &ParamsKZG, - pk: &ProvingKey, - circuit: ConcreteCircuit, - transcript: &mut PoseidonTranscript>, - path: Option<(&Path, &Path)>, -) -> Snark { - gen_snark::, VerifierSHPLONK<_>>( - params, pk, circuit, transcript, path, - ) -} - pub fn read_instances(path: impl AsRef) -> Result>, bincode::Error> { let f = File::open(path)?; let reader = BufReader::new(f); @@ -386,82 +182,18 @@ pub fn write_instances(instances: &[&[Fr]], path: impl AsRef) { bincode::serialize_into(f, &instances).unwrap(); } -pub fn gen_dummy_snark( - params: &ParamsKZG, - vk: Option<&VerifyingKey>, -) -> Snark -where - ConcreteCircuit: CircuitExt, - MOS: MultiOpenScheme - + CostEstimation>>, -{ - struct CsProxy(PhantomData<(F, C)>); - - impl> Circuit for CsProxy { - type Config = C::Config; - type FloorPlanner = C::FloorPlanner; +#[cfg(feature = "zkevm")] +mod zkevm { + use super::CircuitExt; + use eth_types::Field; + use zkevm_circuits::evm_circuit::EvmCircuit; - fn without_witnesses(&self) -> Self { - CsProxy(PhantomData) + impl CircuitExt for EvmCircuit { + fn instances(&self) -> Vec> { + vec![] } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - C::configure(meta) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - // when `C` has simple selectors, we tell `CsProxy` not to over-optimize the selectors (e.g., compressing them all into one) by turning all selectors on in the first row - // currently this only works if all simple selector columns are used in the actual circuit and there are overlaps amongst all enabled selectors (i.e., the actual circuit will not optimize constraint system further) - layouter.assign_region( - || "", - |mut region| { - for q in C::selectors(&config).iter() { - q.enable(&mut region, 0)?; - } - Ok(()) - }, - )?; - Ok(()) + fn num_instance() -> Vec { + vec![] } } - - let dummy_vk = vk - .is_none() - .then(|| keygen_vk(params, &CsProxy::(PhantomData)).unwrap()); - let protocol = compile( - params, - vk.or(dummy_vk.as_ref()).unwrap(), - Config::kzg() - .with_num_instance(ConcreteCircuit::num_instance()) - .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), - ); - let instances = ConcreteCircuit::num_instance() - .into_iter() - .map(|n| iter::repeat(Fr::default()).take(n).collect()) - .collect(); - let proof = { - let mut transcript = PoseidonTranscript::::new(Vec::new()); - for _ in 0..protocol - .num_witness - .iter() - .chain(Some(&protocol.quotient.num_chunk())) - .sum::() - { - transcript.write_ec_point(G1Affine::default()).unwrap(); - } - for _ in 0..protocol.evaluations.len() { - transcript.write_scalar(Fr::default()).unwrap(); - } - let queries = PlonkProof::::empty_queries(&protocol); - for _ in 0..MOS::estimate_cost(&queries).num_commitment { - transcript.write_ec_point(G1Affine::default()).unwrap(); - } - transcript.finalize() - }; - - Snark::new(protocol, instances, proof) } diff --git a/src/sdk/evm.rs b/src/sdk/evm.rs index e69de29b..e6154ee2 100644 --- a/src/sdk/evm.rs +++ b/src/sdk/evm.rs @@ -0,0 +1,191 @@ +use super::CircuitExt; +use crate::{ + halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, + plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, + poly::{ + commitment::{Params, ParamsProver, Prover, Verifier}, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + msm::DualMSM, + multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}, + strategy::{AccumulatorStrategy, GuardKZG}, + }, + VerificationStrategy, + }, + transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, + }, + pcs::kzg::{Bdfg21, Gwc19, Kzg}, + sdk::Plonk, +}; +use crate::{ + loader::evm::{compile_yul, encode_calldata, EvmLoader, ExecutorBuilder}, + pcs::{ + kzg::{KzgAccumulator, KzgDecidingKey, KzgSuccinctVerifyingKey}, + Decider, MultiOpenScheme, PolynomialCommitmentScheme, + }, + system::halo2::{compile, transcript::evm::EvmTranscript, Config}, + verifier::PlonkVerifier, +}; +use ethereum_types::Address; +use itertools::Itertools; +use rand::Rng; +use std::{fs, io, path::Path, rc::Rc}; + +/// Generates a proof for evm verification using either SHPLONK or GWC proving method. Uses Keccak for Fiat-Shamir. +pub fn gen_evm_proof<'params, C, P, V>( + params: &'params ParamsKZG, + pk: &'params ProvingKey, + circuit: C, + instances: Vec>, + rng: &mut impl Rng, +) -> Vec +where + C: Circuit, + P: Prover<'params, KZGCommitmentScheme>, + V: Verifier< + 'params, + KZGCommitmentScheme, + Guard = GuardKZG<'params, Bn256>, + MSMAccumulator = DualMSM<'params, Bn256>, + >, +{ + #[cfg(debug_assertions)] + { + MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); + } + + let instances = instances.iter().map(|instances| instances.as_slice()).collect_vec(); + let proof = { + let mut transcript = TranscriptWriterBuffer::<_, G1Affine, _>::init(Vec::new()); + create_proof::, P, _, _, EvmTranscript<_, _, _, _>, _>( + params, + pk, + &[circuit], + &[instances.as_slice()], + rng, + &mut transcript, + ) + .unwrap(); + transcript.finalize() + }; + + let accept = { + let mut transcript = TranscriptReadBuffer::<_, G1Affine, _>::init(proof.as_slice()); + VerificationStrategy::<_, V>::finalize( + verify_proof::<_, V, _, EvmTranscript<_, _, _, _>, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript, + ) + .unwrap(), + ) + }; + assert!(accept); + + proof +} + +pub fn gen_evm_proof_gwc<'params, C: Circuit>( + params: &'params ParamsKZG, + pk: &'params ProvingKey, + circuit: C, + instances: Vec>, + rng: &mut impl Rng, +) -> Vec { + gen_evm_proof::, VerifierGWC<_>>(params, pk, circuit, instances, rng) +} + +pub fn gen_evm_proof_shplonk<'params, C: Circuit>( + params: &'params ParamsKZG, + pk: &'params ProvingKey, + circuit: C, + instances: Vec>, + rng: &mut impl Rng, +) -> Vec { + gen_evm_proof::, VerifierSHPLONK<_>>(params, pk, circuit, instances, rng) +} + +pub fn gen_evm_verifier( + params: &ParamsKZG, + vk: &VerifyingKey, + path: Option<&Path>, +) -> Vec +where + C: CircuitExt, + PCS: PolynomialCommitmentScheme< + G1Affine, + Rc, + Accumulator = KzgAccumulator>, + > + MultiOpenScheme< + G1Affine, + Rc, + SuccinctVerifyingKey = KzgSuccinctVerifyingKey, + > + Decider, DecidingKey = KzgDecidingKey>, +{ + let svk = params.get_g()[0].into(); + let dk = (params.g2(), params.s_g2()).into(); + let protocol = compile( + params, + vk, + Config::kzg() + .with_num_instance(C::num_instance()) + .with_accumulator_indices(C::accumulator_indices()), + ); + + let loader = EvmLoader::new::(); + let protocol = protocol.loaded(&loader); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); + + let instances = transcript.load_instances(C::num_instance()); + let proof = Plonk::::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); + Plonk::::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); + + let yul_code = loader.yul_code(); + let byte_code = compile_yul(&yul_code); + if let Some(path) = path { + path.parent().and_then(|dir| fs::create_dir_all(dir).ok()).unwrap(); + fs::write(path, yul_code).unwrap(); + } + byte_code +} + +pub fn gen_evm_verifier_gwc>( + params: &ParamsKZG, + vk: &VerifyingKey, + path: Option<&Path>, +) -> Vec { + gen_evm_verifier::>(params, vk, path) +} + +pub fn gen_evm_verifier_shplonk>( + params: &ParamsKZG, + vk: &VerifyingKey, + path: Option<&Path>, +) -> Vec { + gen_evm_verifier::>(params, vk, path) +} + +pub fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { + let calldata = encode_calldata(&instances, &proof); + let success = { + let mut evm = ExecutorBuilder::default().with_gas_limit(u64::MAX.into()).build(); + + let caller = Address::from_low_u64_be(0xfe); + let verifier = evm.deploy(caller, deployment_code.into(), 0.into()).address.unwrap(); + let result = evm.call_raw(caller, verifier, calldata.into(), 0.into()); + + dbg!(result.gas_used); + + !result.reverted + }; + assert!(success); +} + +pub fn write_calldata(instances: &[Vec], proof: &[u8], path: &Path) -> io::Result<()> { + let calldata = encode_calldata(instances, proof); + fs::write(path, hex::encode(calldata)) +} diff --git a/src/sdk/halo2.rs b/src/sdk/halo2.rs index e5e5a4a5..40a65ad7 100644 --- a/src/sdk/halo2.rs +++ b/src/sdk/halo2.rs @@ -1 +1,324 @@ +use super::{read_instances, write_instances, CircuitExt, Snark, SnarkWitness}; +use crate::cost::CostEstimation; +use crate::halo2_proofs; +use crate::pcs::MultiOpenScheme; +use crate::{ + loader::native::NativeLoader, + pcs, + poseidon::Spec, + system::halo2::{compile, Config}, + util::transcript::TranscriptWrite, + verifier::PlonkProof, +}; +#[cfg(feature = "display")] +use ark_std::{end_timer, start_timer}; +use halo2_proofs::{ + circuit::Layouter, + dev::MockProver, + halo2curves::{ + bn256::{Bn256, Fr, G1Affine}, + group::ff::Field, + }, + plonk::{ + create_proof, keygen_vk, verify_proof, Circuit, ConstraintSystem, Error, ProvingKey, + VerifyingKey, + }, + poly::{ + commitment::{Params, ParamsProver, Prover, Verifier}, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + msm::DualMSM, + multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}, + strategy::{AccumulatorStrategy, GuardKZG}, + }, + VerificationStrategy, + }, +}; +use itertools::Itertools; +use lazy_static::lazy_static; +use rand::Rng; +use std::{fs, iter, marker::PhantomData, path::Path}; + pub mod aggregation; + +// Poseidon parameters +const T: usize = 5; +const RATE: usize = 4; +const R_F: usize = 8; +const R_P: usize = 60; + +pub type PoseidonTranscript = + crate::system::halo2::transcript::halo2::PoseidonTranscript; + +lazy_static! { + pub static ref POSEIDON_SPEC: Spec = Spec::new(R_F, R_P); +} + +/// Generates a native proof using either SHPLONK or GWC proving method. Uses Poseidon for Fiat-Shamir. +/// +/// Caches the instances and proof if `path` is specified. +pub fn gen_proof<'params, C, P, V>( + params: &'params ParamsKZG, + pk: &'params ProvingKey, + circuit: C, + instances: Vec>, + transcript: &mut PoseidonTranscript>, + rng: &mut impl Rng, + path: Option<(&Path, &Path)>, +) -> Vec +where + C: Circuit, + P: Prover<'params, KZGCommitmentScheme>, + V: Verifier< + 'params, + KZGCommitmentScheme, + Guard = GuardKZG<'params, Bn256>, + MSMAccumulator = DualMSM<'params, Bn256>, + >, +{ + #[cfg(debug_assertions)] + { + MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); + } + + let mut proof: Option> = None; + + if let Some((instance_path, proof_path)) = path { + let cached_instances = read_instances(instance_path); + if matches!(cached_instances, Ok(tmp) if tmp == instances) && proof_path.exists() { + #[cfg(feature = "display")] + let read_time = start_timer!(|| format!("Reading proof from {proof_path:?}")); + + proof = Some(fs::read(proof_path).unwrap()); + + #[cfg(feature = "display")] + end_timer!(read_time); + } + } + + let instances = instances.iter().map(Vec::as_slice).collect_vec(); + + let proof = proof.unwrap_or_else(|| { + #[cfg(feature = "display")] + let proof_time = start_timer!(|| "Create proof"); + + transcript.clear(); + create_proof::<_, P, _, _, _, _>(params, pk, &[circuit], &[&instances], rng, transcript) + .unwrap(); + let proof = transcript.stream_mut().split_off(0); + + #[cfg(feature = "display")] + end_timer!(proof_time); + + if let Some((instance_path, proof_path)) = path { + write_instances(&instances, instance_path); + fs::write(proof_path, &proof).unwrap(); + } + proof + }); + + debug_assert!({ + let mut transcript = PoseidonTranscript::::new(proof.as_slice()); + VerificationStrategy::<_, V>::finalize( + verify_proof::<_, V, _, _, _>( + params.verifier_params(), + pk.get_vk(), + AccumulatorStrategy::new(params.verifier_params()), + &[instances.as_slice()], + &mut transcript, + ) + .unwrap(), + ) + }); + + proof +} + +/// Generates a native proof using original Plonk (GWC '19) multi-open scheme. Uses Poseidon for Fiat-Shamir. +/// +/// Caches the instances and proof if `path` is specified. +pub fn gen_proof_gwc>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: C, + instances: Vec>, + transcript: &mut PoseidonTranscript>, + rng: &mut impl Rng, + path: Option<(&Path, &Path)>, +) -> Vec { + gen_proof::, VerifierGWC<_>>( + params, pk, circuit, instances, transcript, rng, path, + ) +} + +/// Generates a native proof using SHPLONK multi-open scheme. Uses Poseidon for Fiat-Shamir. +/// +/// Caches the instances and proof if `path` is specified. +pub fn gen_proof_shplonk>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: C, + instances: Vec>, + transcript: &mut PoseidonTranscript>, + rng: &mut impl Rng, + path: Option<(&Path, &Path)>, +) -> Vec { + gen_proof::, VerifierSHPLONK<_>>( + params, pk, circuit, instances, transcript, rng, path, + ) +} + +/// Generates a SNARK using either SHPLONK or GWC multi-open scheme. Uses Poseidon for Fiat-Shamir. +/// +/// Caches the instances and proof if `path` is specified. +pub fn gen_snark<'params, ConcreteCircuit, P, V>( + params: &'params ParamsKZG, + pk: &'params ProvingKey, + circuit: ConcreteCircuit, + transcript: &mut PoseidonTranscript>, + rng: &mut impl Rng, + path: Option<(&Path, &Path)>, +) -> Snark +where + ConcreteCircuit: CircuitExt, + P: Prover<'params, KZGCommitmentScheme>, + V: Verifier< + 'params, + KZGCommitmentScheme, + Guard = GuardKZG<'params, Bn256>, + MSMAccumulator = DualMSM<'params, Bn256>, + >, +{ + let protocol = compile( + params, + pk.get_vk(), + Config::kzg() + .with_num_instance(ConcreteCircuit::num_instance()) + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), + ); + + let instances = circuit.instances(); + let proof = gen_proof::( + params, + pk, + circuit, + instances.clone(), + transcript, + rng, + path, + ); + + Snark::new(protocol, instances, proof) +} + +/// Generates a SNARK using GWC multi-open scheme. Uses Poseidon for Fiat-Shamir. +/// +/// Caches the instances and proof if `path` is specified. +pub fn gen_snark_gwc>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: ConcreteCircuit, + transcript: &mut PoseidonTranscript>, + rng: &mut impl Rng, + path: Option<(&Path, &Path)>, +) -> Snark { + gen_snark::, VerifierGWC<_>>( + params, pk, circuit, transcript, rng, path, + ) +} + +/// Generates a SNARK using SHPLONK multi-open scheme. Uses Poseidon for Fiat-Shamir. +/// +/// Caches the instances and proof if `path` is specified. +pub fn gen_snark_shplonk>( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: ConcreteCircuit, + transcript: &mut PoseidonTranscript>, + rng: &mut impl Rng, + path: Option<(&Path, &Path)>, +) -> Snark { + gen_snark::, VerifierSHPLONK<_>>( + params, pk, circuit, transcript, rng, path, + ) +} + +pub fn gen_dummy_snark( + params: &ParamsKZG, + vk: Option<&VerifyingKey>, +) -> Snark +where + ConcreteCircuit: CircuitExt, + MOS: MultiOpenScheme + + CostEstimation>>, +{ + struct CsProxy(PhantomData<(F, C)>); + + impl> Circuit for CsProxy { + type Config = C::Config; + type FloorPlanner = C::FloorPlanner; + + fn without_witnesses(&self) -> Self { + CsProxy(PhantomData) + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + C::configure(meta) + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // when `C` has simple selectors, we tell `CsProxy` not to over-optimize the selectors (e.g., compressing them all into one) by turning all selectors on in the first row + // currently this only works if all simple selector columns are used in the actual circuit and there are overlaps amongst all enabled selectors (i.e., the actual circuit will not optimize constraint system further) + layouter.assign_region( + || "", + |mut region| { + for q in C::selectors(&config).iter() { + q.enable(&mut region, 0)?; + } + Ok(()) + }, + )?; + Ok(()) + } + } + + let dummy_vk = vk + .is_none() + .then(|| keygen_vk(params, &CsProxy::(PhantomData)).unwrap()); + let protocol = compile( + params, + vk.or(dummy_vk.as_ref()).unwrap(), + Config::kzg() + .with_num_instance(ConcreteCircuit::num_instance()) + .with_accumulator_indices(ConcreteCircuit::accumulator_indices()), + ); + let instances = ConcreteCircuit::num_instance() + .into_iter() + .map(|n| iter::repeat(Fr::default()).take(n).collect()) + .collect(); + let proof = { + let mut transcript = PoseidonTranscript::::new(Vec::new()); + for _ in 0..protocol + .num_witness + .iter() + .chain(Some(&protocol.quotient.num_chunk())) + .sum::() + { + transcript.write_ec_point(G1Affine::default()).unwrap(); + } + for _ in 0..protocol.evaluations.len() { + transcript.write_scalar(Fr::default()).unwrap(); + } + let queries = PlonkProof::::empty_queries(&protocol); + for _ in 0..MOS::estimate_cost(&queries).num_commitment { + transcript.write_ec_point(G1Affine::default()).unwrap(); + } + transcript.finalize() + }; + + Snark::new(protocol, instances, proof) +} diff --git a/src/sdk/halo2/aggregation.rs b/src/sdk/halo2/aggregation.rs index e1766c4c..d5400a41 100644 --- a/src/sdk/halo2/aggregation.rs +++ b/src/sdk/halo2/aggregation.rs @@ -1,43 +1,44 @@ use crate::halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; use crate::halo2_proofs::{ - circuit::{Cell, Layouter, SimpleFloorPlanner, Value}, - plonk::{self, Circuit, Column, ConstraintSystem, Instance}, + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{self, Circuit, Column, ConstraintSystem, Instance, Selector}, poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, }; -use crate::pcs::kzg::Bdfg21; use crate::{ loader::{self, native::NativeLoader}, pcs::{ - kzg::{Kzg, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding}, + kzg::{Bdfg21, Kzg, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey}, AccumulationScheme, AccumulationSchemeProver, MultiOpenScheme, PolynomialCommitmentScheme, }, - system, + sdk::{Plonk, BITS, LIMBS}, util::arithmetic::fe_to_limbs, - verifier::{self, PlonkVerifier}, - Protocol, + verifier::PlonkVerifier, }; #[cfg(feature = "display")] use ark_std::{end_timer, start_timer}; -use halo2_base::{AssignedValue, Context, ContextParams}; +use halo2_base::{Context, ContextParams}; use halo2_ecc::ecc::EccChip; use itertools::Itertools; -use rand::rngs::OsRng; -use rand::{Rng, SeedableRng}; -use rand_chacha::ChaCha20Rng; +use rand::Rng; use std::{fs::File, rc::Rc}; -const LIMBS: usize = 3; -const BITS: usize = 88; - -use crate::sdk::{PoseidonTranscript, Snark, SnarkWitness, POSEIDON_SPEC}; +use super::{CircuitExt, PoseidonTranscript, Snark, SnarkWitness, POSEIDON_SPEC}; type Svk = KzgSuccinctVerifyingKey; type BaseFieldEccChip = halo2_ecc::ecc::BaseFieldEccChip; type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; -/// PCS be either `Kzg` or `Kzg` -type Plonk = verifier::Plonk>; type Shplonk = Plonk>; +pub fn load_verify_circuit_degree() -> u32 { + let path = std::env::var("VERIFY_CONFIG") + .unwrap_or_else(|_| "./configs/verify_circuit.config".to_string()); + let params: AggregationConfigParams = serde_json::from_reader( + File::open(path.as_str()).unwrap_or_else(|_| panic!("{path} does not exist")), + ) + .unwrap(); + params.degree +} + /// Core function used in `synthesize` to aggregate multiple `snarks`. /// /// Returns the assigned instances of previous snarks (all concatenated together) and the new final pair that needs to be verified in a pairing check @@ -95,7 +96,7 @@ where }) .collect_vec(); - let acccumulator = if accumulators.len() > 1 { + let accumulator = if accumulators.len() > 1 { transcript.new_stream(as_proof); let proof = KzgAs::::read_proof(&Default::default(), &accumulators, &mut transcript).unwrap(); @@ -104,7 +105,7 @@ where accumulators.pop().unwrap() }; - (previous_instances, acccumulator) + (previous_instances, accumulator) } #[derive(serde::Serialize, serde::Deserialize)] @@ -244,6 +245,25 @@ impl AggregationCircuit { } } +impl CircuitExt for AggregationCircuit { + fn num_instance() -> Vec { + // [..lhs, ..rhs] + vec![4 * LIMBS] + } + + fn instances(&self) -> Vec> { + vec![self.instances.clone()] + } + + fn accumulator_indices() -> Option> { + Some((0..4 * LIMBS).map(|idx| (0, idx)).collect()) + } + + fn selectors(config: &Self::Config) -> Vec { + config.gate().basic_gates.iter().map(|gate| gate.q_enable).collect() + } +} + impl Circuit for AggregationCircuit { type Config = AggregationConfig; type FloorPlanner = SimpleFloorPlanner; @@ -323,7 +343,16 @@ impl Circuit for AggregationCircuit { .chain(lhs.y.truncation.limbs.iter()) .chain(rhs.x.truncation.limbs.iter()) .chain(rhs.y.truncation.limbs.iter()) - .map(|assigned| assigned.cell().clone()) + .map(|assigned| { + #[cfg(feature = "halo2-axiom")] + { + *assigned.cell() + } + #[cfg(feature = "halo2-pse")] + { + assigned.cell() + } + }) .collect_vec(); #[cfg(feature = "display")] end_timer!(witness_time); diff --git a/src/system.rs b/src/system.rs index 5d5aa99c..edf79228 100644 --- a/src/system.rs +++ b/src/system.rs @@ -1,2 +1 @@ -#[cfg(feature = "system_halo2")] pub mod halo2; diff --git a/src/system/halo2.rs b/src/system/halo2.rs index 491ae14f..2c537f10 100644 --- a/src/system/halo2.rs +++ b/src/system/halo2.rs @@ -20,6 +20,7 @@ use std::{io, iter, mem::size_of}; pub mod transcript; #[cfg(test)] +#[cfg(feature = "loader_halo2")] pub(crate) mod test; #[derive(Clone, Debug, Default)] diff --git a/src/system/halo2/test/circuit/standard.rs b/src/system/halo2/test/circuit/standard.rs index 5bbdeb4f..bfa94df4 100644 --- a/src/system/halo2/test/circuit/standard.rs +++ b/src/system/halo2/test/circuit/standard.rs @@ -88,9 +88,9 @@ impl Circuit for StandardPlonk { #[cfg(feature = "halo2-pse")] { region.assign_advice(|| "", config.a, 0, || Value::known(self.0))?; - region.assign_fixed(|| "", config.q_a, 0, || Value::known(-Fr::one()))?; + region.assign_fixed(|| "", config.q_a, 0, || Value::known(-F::one()))?; - region.assign_advice(|| "", config.a, 1, || Value::known(-Fr::from(5u64)))?; + region.assign_advice(|| "", config.a, 1, || Value::known(-F::from(5u64)))?; for (idx, column) in (1..).zip([ config.q_a, config.q_b, @@ -102,11 +102,11 @@ impl Circuit for StandardPlonk { || "", column, 1, - || Value::known(Fr::from(idx as u64)), + || Value::known(F::from(idx as u64)), )?; } - let a = region.assign_advice(|| "", config.a, 2, || Value::known(Fr::one()))?; + let a = region.assign_advice(|| "", config.a, 2, || Value::known(F::one()))?; a.copy_advice(|| "", &mut region, config.b, 3)?; a.copy_advice(|| "", &mut region, config.c, 4)?; } diff --git a/src/system/halo2/transcript/evm.rs b/src/system/halo2/transcript/evm.rs index 32d60b8f..909bb71d 100644 --- a/src/system/halo2/transcript/evm.rs +++ b/src/system/halo2/transcript/evm.rs @@ -1,3 +1,4 @@ +use crate::halo2_proofs; use crate::{ loader::{ evm::{loader::Value, u256_to_fe, EcPoint, EvmLoader, MemoryChunk, Scalar}, @@ -37,12 +38,7 @@ where assert_eq!(ptr, 0); let mut buf = MemoryChunk::new(ptr); buf.extend(0x20); - Self { - loader: loader.clone(), - stream: 0, - buf, - _marker: PhantomData, - } + Self { loader: loader.clone(), stream: 0, buf, _marker: PhantomData } } pub fn load_instances(&mut self, num_instance: Vec) -> Vec> { @@ -149,12 +145,7 @@ where C: CurveAffine, { pub fn new(stream: S) -> Self { - Self { - loader: NativeLoader, - stream, - buf: Vec::new(), - _marker: PhantomData, - } + Self { loader: NativeLoader, stream, buf: Vec::new(), _marker: PhantomData } } } @@ -172,11 +163,7 @@ where .buf .iter() .cloned() - .chain(if self.buf.len() == 0x20 { - Some(1) - } else { - None - }) + .chain(if self.buf.len() == 0x20 { Some(1) } else { None }) .collect_vec(); let hash: [u8; 32] = Keccak256::digest(data).into(); self.buf = hash.to_vec(); @@ -193,8 +180,7 @@ where })?; [coordinates.x(), coordinates.y()].map(|coordinate| { - self.buf - .extend(coordinate.to_repr().as_ref().iter().rev().cloned()); + self.buf.extend(coordinate.to_repr().as_ref().iter().rev().cloned()); }); Ok(()) @@ -220,10 +206,7 @@ where .map_err(|err| Error::Transcript(err.kind(), err.to_string()))?; data.reverse(); let scalar = C::Scalar::from_repr_vartime(data).ok_or_else(|| { - Error::Transcript( - io::ErrorKind::Other, - "Invalid scalar encoding in proof".to_string(), - ) + Error::Transcript(io::ErrorKind::Other, "Invalid scalar encoding in proof".to_string()) })?; self.common_scalar(&scalar)?; Ok(scalar) @@ -239,10 +222,8 @@ where } let x = Option::from(::from_repr(x)); let y = Option::from(::from_repr(y)); - let ec_point = x - .zip(y) - .and_then(|(x, y)| Option::from(C::from_xy(x, y))) - .ok_or_else(|| { + let ec_point = + x.zip(y).and_then(|(x, y)| Option::from(C::from_xy(x, y))).ok_or_else(|| { Error::Transcript( io::ErrorKind::Other, "Invalid elliptic curve point encoding in proof".to_string(), diff --git a/src/system/halo2/transcript/halo2.rs b/src/system/halo2/transcript/halo2.rs index 2b564648..5e343740 100644 --- a/src/system/halo2/transcript/halo2.rs +++ b/src/system/halo2/transcript/halo2.rs @@ -404,13 +404,13 @@ where mod halo2_lib { use crate::halo2_curves::CurveAffineExt; use crate::system::halo2::transcript::halo2::NativeEncoding; - use halo2_base::utils::BigPrimeField; + use halo2_base::utils::PrimeField; use halo2_ecc::ecc::BaseFieldEccChip; impl<'a, C: CurveAffineExt> NativeEncoding<'a, C> for BaseFieldEccChip where - C::Scalar: BigPrimeField, - C::Base: BigPrimeField, + C::Scalar: PrimeField, + C::Base: PrimeField, { fn encode( &self, From 71f2b6f4bc24c99aff7b3718fc6e89d5034dfbc8 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Tue, 6 Dec 2022 10:29:14 -0500 Subject: [PATCH 27/28] chore: update Cargo.toml --- Cargo.toml | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d73c60f5..10d8f406 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,8 +17,8 @@ serde_json = "1.0" bincode = "1.3.3" ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } -# Use halo2_base as non-optional dependency because it re-exports halo2_proofs, halo2curves, and poseidon, using different repos based on feature flag "halo2-axiom" or "halo2-pse" -halo2_base = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", tag = "v0.2.1", package = "halo2_base", default-features = false } +# Use halo2-base as non-optional dependency because it re-exports halo2_proofs, halo2curves, and poseidon, using different repos based on feature flag "halo2-axiom" or "halo2-pse" +halo2-base = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", branch = "experiment/optimizations", default-features = false } # parallel rayon = { version = "1.5.3", optional = true } @@ -31,13 +31,13 @@ bytes = { version = "1.2", optional = true } rlp = { version = "0.5", default-features = false, features = ["std"], optional = true } # loader_halo2 -halo2_ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", tag = "v0.2.1", package = "halo2_ecc", default-features = false, optional = true } +halo2-ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", branch = "experiment/optimizations", default-features = false, optional = true } # zkevm benchmarks -zkevm-circuits = { path = "../zkevm-circuits/zkevm-circuits", features = ["test"], optional = true } -bus-mapping = { path = "../zkevm-circuits/bus-mapping", optional = true } -eth-types = { path = "../zkevm-circuits/eth-types", optional = true } -mock = { path = "../zkevm-circuits/mock", optional = true } +zkevm-circuits = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", features = ["test"], optional = true } +bus-mapping = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", optional = true } +eth-types = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", optional = true } +mock = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", optional = true } [dev-dependencies] ark-std = { version = "0.3.0", features = ["print-trace"] } @@ -50,16 +50,16 @@ crossterm = { version = "0.25" } tui = { version = "0.19", default-features = false, features = ["crossterm"] } [features] -default = ["loader_evm", "loader_halo2", "zkevm", "halo2-pse", "halo2_ecc?/jemallocator"] -display = ["halo2_base/display", "halo2_ecc?/display", "ark-std"] +default = ["loader_evm", "loader_halo2", "zkevm", "halo2-pse", "halo2-ecc?/jemallocator"] +display = ["halo2-base/display", "halo2-ecc?/display", "ark-std"] # EXACTLY one of halo2-pse / halo2-axiom should always be turned on; not sure how to enforce this with Cargo -halo2-pse = ["halo2_base/halo2-pse", "halo2_ecc?/halo2-pse"] -halo2-axiom = ["halo2_base/halo2-axiom", "halo2_ecc?/halo2-axiom"] +halo2-pse = ["halo2-base/halo2-pse", "halo2-ecc?/halo2-pse"] +halo2-axiom = ["halo2-base/halo2-axiom", "halo2-ecc?/halo2-axiom"] parallel = ["dep:rayon"] loader_evm = ["dep:ethereum_types", "dep:sha3", "dep:revm", "dep:bytes", "dep:rlp"] -loader_halo2 = ["halo2_ecc"] +loader_halo2 = ["halo2-ecc"] zkevm = ["dep:zkevm-circuits", "dep:bus-mapping", "dep:mock", "dep:eth-types"] @@ -102,11 +102,13 @@ lto = "thin" codegen-units = 16 [profile.release] -debug = 0 opt-level = 3 +debug = false +debug-assertions = false lto = "fat" # codegen-units = 1 panic = "abort" +incremental = false # For performance profiling [profile.flamegraph] @@ -114,14 +116,11 @@ inherits = "release" debug = true [patch."ssh://github.com/axiom-crypto/halo2-lib-working.git"] -halo2_base = { path = "../halo2-lib-working/halo2_base" } -halo2_ecc = { path = "../halo2-lib-working/halo2_ecc" } +halo2-base = { path = "../halo2-lib-working/halo2-base" } +halo2-ecc = { path = "../halo2-lib-working/halo2-ecc" } [patch."https://github.com/privacy-scaling-explorations/halo2curves.git"] halo2curves = { path = "../halo2/arithmetic/curves" } [patch."https://github.com/privacy-scaling-explorations/halo2.git"] halo2_proofs = { git = "https://github.com/axiom-crypto/halo2.git", branch = "axiom/add-serialization" } - -[patch."https://github.com/scroll-tech/halo2.git"] -halo2_proofs = { git = "https://github.com/axiom-crypto/halo2.git", branch = "axiom/add-serialization" } \ No newline at end of file From b3ad21fa57a5e5e130cca9632e54876983d25938 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Thu, 15 Dec 2022 16:19:24 -0500 Subject: [PATCH 28/28] reorg: rename repo to snark-verifier * split into two crates: snark-verifier and snark-verifier-sdk * previous module sdk moved to snark-verifier-sdk --- Cargo.toml | 108 +++------------ snark-verifier-sdk/Cargo.toml | 66 +++++++++ .../benches}/standard_plonk.rs | 40 ++---- .../benches}/zkevm.rs | 34 ++--- .../configs}/bench_zkevm.config | 0 .../configs/example_evm_accumulator.config | 1 + .../configs}/verify_circuit.config | 0 {src/sdk => snark-verifier-sdk/src}/evm.rs | 42 +++--- {src/sdk => snark-verifier-sdk/src}/halo2.rs | 30 ++-- .../src}/halo2/aggregation.rs | 31 +++-- src/sdk.rs => snark-verifier-sdk/src/lib.rs | 5 +- snark-verifier/Cargo.toml | 61 +++++++++ .../configs/example_evm_accumulator.config | 1 + .../configs/verify_circuit.config | 0 .../examples}/README.md | 2 +- .../evm-verifier-with-accumulator.rs | 128 +++++++++--------- .../examples}/evm-verifier.rs | 4 +- .../examples}/recursion.rs | 23 ++-- {src => snark-verifier/src}/cost.rs | 0 {src => snark-verifier/src}/lib.rs | 1 - {src => snark-verifier/src}/loader.rs | 0 {src => snark-verifier/src}/loader/evm.rs | 0 .../src}/loader/evm/code.rs | 0 .../src}/loader/evm/loader.rs | 0 .../src}/loader/evm/test.rs | 0 .../src}/loader/evm/test/tui.rs | 0 .../src}/loader/evm/util.rs | 0 .../src}/loader/evm/util/executor.rs | 0 {src => snark-verifier/src}/loader/halo2.rs | 2 +- .../src}/loader/halo2/loader.rs | 0 .../src}/loader/halo2/shim.rs | 0 .../src}/loader/halo2/test.rs | 0 {src => snark-verifier/src}/loader/native.rs | 0 {src => snark-verifier/src}/pcs.rs | 0 {src => snark-verifier/src}/pcs/ipa.rs | 0 .../src}/pcs/ipa/accumulation.rs | 0 .../src}/pcs/ipa/accumulator.rs | 0 .../src}/pcs/ipa/decider.rs | 0 .../src}/pcs/ipa/multiopen.rs | 0 .../src}/pcs/ipa/multiopen/bgh19.rs | 0 {src => snark-verifier/src}/pcs/kzg.rs | 0 .../src}/pcs/kzg/accumulation.rs | 0 .../src}/pcs/kzg/accumulator.rs | 0 .../src}/pcs/kzg/decider.rs | 0 .../src}/pcs/kzg/multiopen.rs | 0 .../src}/pcs/kzg/multiopen/bdfg21.rs | 0 .../src}/pcs/kzg/multiopen/gwc19.rs | 0 {src => snark-verifier/src}/system.rs | 0 {src => snark-verifier/src}/system/halo2.rs | 2 +- .../src}/system/halo2/aggregation.rs | 0 .../src}/system/halo2/strategy.rs | 0 .../src}/system/halo2/test.rs | 0 .../src}/system/halo2/test/circuit.rs | 0 .../system/halo2/test/circuit/maingate.rs | 0 .../system/halo2/test/circuit/standard.rs | 0 .../src}/system/halo2/test/ipa.rs | 0 .../src}/system/halo2/test/ipa/native.rs | 0 .../src}/system/halo2/test/kzg.rs | 0 .../src}/system/halo2/test/kzg/evm.rs | 0 .../src}/system/halo2/test/kzg/halo2.rs | 2 +- .../src}/system/halo2/test/kzg/native.rs | 0 .../src}/system/halo2/transcript.rs | 0 .../src}/system/halo2/transcript/evm.rs | 0 .../src}/system/halo2/transcript/halo2.rs | 0 {src => snark-verifier/src}/util.rs | 0 .../src}/util/arithmetic.rs | 0 {src => snark-verifier/src}/util/hash.rs | 0 .../src}/util/hash/poseidon.rs | 0 {src => snark-verifier/src}/util/msm.rs | 0 {src => snark-verifier/src}/util/poly.rs | 0 {src => snark-verifier/src}/util/protocol.rs | 0 .../src}/util/transcript.rs | 0 {src => snark-verifier/src}/verifier.rs | 0 {src => snark-verifier/src}/verifier/plonk.rs | 0 src/system/halo2/test/README.md | 19 --- 75 files changed, 314 insertions(+), 288 deletions(-) create mode 100644 snark-verifier-sdk/Cargo.toml rename {benches => snark-verifier-sdk/benches}/standard_plonk.rs (88%) rename {benches => snark-verifier-sdk/benches}/zkevm.rs (88%) rename {configs => snark-verifier-sdk/configs}/bench_zkevm.config (100%) create mode 100644 snark-verifier-sdk/configs/example_evm_accumulator.config rename {configs => snark-verifier-sdk/configs}/verify_circuit.config (100%) rename {src/sdk => snark-verifier-sdk/src}/evm.rs (87%) rename {src/sdk => snark-verifier-sdk/src}/halo2.rs (97%) rename {src/sdk => snark-verifier-sdk/src}/halo2/aggregation.rs (95%) rename src/sdk.rs => snark-verifier-sdk/src/lib.rs (97%) create mode 100644 snark-verifier/Cargo.toml create mode 100644 snark-verifier/configs/example_evm_accumulator.config rename configs/example_evm_accumulator.config => snark-verifier/configs/verify_circuit.config (100%) rename {examples => snark-verifier/examples}/README.md (77%) rename {examples => snark-verifier/examples}/evm-verifier-with-accumulator.rs (89%) rename {examples => snark-verifier/examples}/evm-verifier.rs (99%) rename {examples => snark-verifier/examples}/recursion.rs (98%) rename {src => snark-verifier/src}/cost.rs (100%) rename {src => snark-verifier/src}/lib.rs (99%) rename {src => snark-verifier/src}/loader.rs (100%) rename {src => snark-verifier/src}/loader/evm.rs (100%) rename {src => snark-verifier/src}/loader/evm/code.rs (100%) rename {src => snark-verifier/src}/loader/evm/loader.rs (100%) rename {src => snark-verifier/src}/loader/evm/test.rs (100%) rename {src => snark-verifier/src}/loader/evm/test/tui.rs (100%) rename {src => snark-verifier/src}/loader/evm/util.rs (100%) rename {src => snark-verifier/src}/loader/evm/util/executor.rs (100%) rename {src => snark-verifier/src}/loader/halo2.rs (98%) rename {src => snark-verifier/src}/loader/halo2/loader.rs (100%) rename {src => snark-verifier/src}/loader/halo2/shim.rs (100%) rename {src => snark-verifier/src}/loader/halo2/test.rs (100%) rename {src => snark-verifier/src}/loader/native.rs (100%) rename {src => snark-verifier/src}/pcs.rs (100%) rename {src => snark-verifier/src}/pcs/ipa.rs (100%) rename {src => snark-verifier/src}/pcs/ipa/accumulation.rs (100%) rename {src => snark-verifier/src}/pcs/ipa/accumulator.rs (100%) rename {src => snark-verifier/src}/pcs/ipa/decider.rs (100%) rename {src => snark-verifier/src}/pcs/ipa/multiopen.rs (100%) rename {src => snark-verifier/src}/pcs/ipa/multiopen/bgh19.rs (100%) rename {src => snark-verifier/src}/pcs/kzg.rs (100%) rename {src => snark-verifier/src}/pcs/kzg/accumulation.rs (100%) rename {src => snark-verifier/src}/pcs/kzg/accumulator.rs (100%) rename {src => snark-verifier/src}/pcs/kzg/decider.rs (100%) rename {src => snark-verifier/src}/pcs/kzg/multiopen.rs (100%) rename {src => snark-verifier/src}/pcs/kzg/multiopen/bdfg21.rs (100%) rename {src => snark-verifier/src}/pcs/kzg/multiopen/gwc19.rs (100%) rename {src => snark-verifier/src}/system.rs (100%) rename {src => snark-verifier/src}/system/halo2.rs (99%) rename {src => snark-verifier/src}/system/halo2/aggregation.rs (100%) rename {src => snark-verifier/src}/system/halo2/strategy.rs (100%) rename {src => snark-verifier/src}/system/halo2/test.rs (100%) rename {src => snark-verifier/src}/system/halo2/test/circuit.rs (100%) rename {src => snark-verifier/src}/system/halo2/test/circuit/maingate.rs (100%) rename {src => snark-verifier/src}/system/halo2/test/circuit/standard.rs (100%) rename {src => snark-verifier/src}/system/halo2/test/ipa.rs (100%) rename {src => snark-verifier/src}/system/halo2/test/ipa/native.rs (100%) rename {src => snark-verifier/src}/system/halo2/test/kzg.rs (100%) rename {src => snark-verifier/src}/system/halo2/test/kzg/evm.rs (100%) rename {src => snark-verifier/src}/system/halo2/test/kzg/halo2.rs (99%) rename {src => snark-verifier/src}/system/halo2/test/kzg/native.rs (100%) rename {src => snark-verifier/src}/system/halo2/transcript.rs (100%) rename {src => snark-verifier/src}/system/halo2/transcript/evm.rs (100%) rename {src => snark-verifier/src}/system/halo2/transcript/halo2.rs (100%) rename {src => snark-verifier/src}/util.rs (100%) rename {src => snark-verifier/src}/util/arithmetic.rs (100%) rename {src => snark-verifier/src}/util/hash.rs (100%) rename {src => snark-verifier/src}/util/hash/poseidon.rs (100%) rename {src => snark-verifier/src}/util/msm.rs (100%) rename {src => snark-verifier/src}/util/poly.rs (100%) rename {src => snark-verifier/src}/util/protocol.rs (100%) rename {src => snark-verifier/src}/util/transcript.rs (100%) rename {src => snark-verifier/src}/verifier.rs (100%) rename {src => snark-verifier/src}/verifier/plonk.rs (100%) delete mode 100644 src/system/halo2/test/README.md diff --git a/Cargo.toml b/Cargo.toml index 10d8f406..17890512 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,89 +1,8 @@ -[package] -name = "plonk_verifier" -version = "0.1.0" -edition = "2021" - -[dependencies] -itertools = "0.10.3" -lazy_static = "1.4.0" -num-bigint = "0.4.3" -num-integer = "0.1.45" -num-traits = "0.2.15" -rand = "0.8" -rand_chacha = "0.3.1" -hex = "0.4" -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -bincode = "1.3.3" -ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } - -# Use halo2-base as non-optional dependency because it re-exports halo2_proofs, halo2curves, and poseidon, using different repos based on feature flag "halo2-axiom" or "halo2-pse" -halo2-base = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", branch = "experiment/optimizations", default-features = false } - -# parallel -rayon = { version = "1.5.3", optional = true } - -# loader_evm -ethereum_types = { package = "ethereum-types", version = "0.14", default-features = false, features = ["std"], optional = true } -sha3 = { version = "0.10", optional = true } -revm = { version = "2.3.1", optional = true } -bytes = { version = "1.2", optional = true } -rlp = { version = "0.5", default-features = false, features = ["std"], optional = true } - -# loader_halo2 -halo2-ecc = { git = "ssh://github.com/axiom-crypto/halo2-lib-working.git", branch = "experiment/optimizations", default-features = false, optional = true } - -# zkevm benchmarks -zkevm-circuits = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", features = ["test"], optional = true } -bus-mapping = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", optional = true } -eth-types = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", optional = true } -mock = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", optional = true } - -[dev-dependencies] -ark-std = { version = "0.3.0", features = ["print-trace"] } -paste = "1.0.7" -pprof = { version = "0.11", features = ["criterion", "flamegraph"] } -criterion = "0.4" -criterion-macro = "0.4" -# loader_evm -crossterm = { version = "0.25" } -tui = { version = "0.19", default-features = false, features = ["crossterm"] } - -[features] -default = ["loader_evm", "loader_halo2", "zkevm", "halo2-pse", "halo2-ecc?/jemallocator"] -display = ["halo2-base/display", "halo2-ecc?/display", "ark-std"] -# EXACTLY one of halo2-pse / halo2-axiom should always be turned on; not sure how to enforce this with Cargo -halo2-pse = ["halo2-base/halo2-pse", "halo2-ecc?/halo2-pse"] -halo2-axiom = ["halo2-base/halo2-axiom", "halo2-ecc?/halo2-axiom"] - -parallel = ["dep:rayon"] - -loader_evm = ["dep:ethereum_types", "dep:sha3", "dep:revm", "dep:bytes", "dep:rlp"] -loader_halo2 = ["halo2-ecc"] - -zkevm = ["dep:zkevm-circuits", "dep:bus-mapping", "dep:mock", "dep:eth-types"] - -[[example]] -name = "evm-verifier" -required-features = ["loader_evm"] - -[[example]] -name = "evm-verifier-with-accumulator" -required-features = ["loader_halo2", "loader_evm"] - -[[example]] -name = "recursion" -required-features = ["loader_halo2"] - -[[bench]] -name = "standard_plonk" -required-features = ["loader_halo2"] -harness = false - -[[bench]] -name = "zkevm" -required-features = ["loader_halo2", "zkevm", "halo2-pse"] -harness = false +[workspace] +members = [ + "snark-verifier", + "snark-verifier-sdk", +] [profile.dev] opt-level = 3 @@ -115,12 +34,19 @@ incremental = false inherits = "release" debug = true -[patch."ssh://github.com/axiom-crypto/halo2-lib-working.git"] -halo2-base = { path = "../halo2-lib-working/halo2-base" } -halo2-ecc = { path = "../halo2-lib-working/halo2-ecc" } +[patch."ssh://github.com/axiom-crypto/axiom-core-working.git"] +halo2-base = { path = "../axiom-core-working/halo2-lib/halo2-base" } +halo2-ecc = { path = "../axiom-core-working/halo2-lib/halo2-ecc" } -[patch."https://github.com/privacy-scaling-explorations/halo2curves.git"] +[patch."https://github.com/axiom-crypto/halo2.git"] +halo2_proofs = { path = "../halo2/halo2_proofs" } halo2curves = { path = "../halo2/arithmetic/curves" } +poseidon = { path = "../halo2/primitives/poseidon" } + +# patch for now because PSE/halo2 has not yet updated halo2curves version, unnecessary if halo2_proofs is using latest halo2curves with Fq12 public +[patch."https://github.com/privacy-scaling-explorations/halo2curves.git"] +halo2curves = { path = "../halo2/arithmetic/curves" } +# patch just because we cannot patch the same repo with different tag: serialization is already in latest PSE/halo2 but not in v2022_10_22 [patch."https://github.com/privacy-scaling-explorations/halo2.git"] -halo2_proofs = { git = "https://github.com/axiom-crypto/halo2.git", branch = "axiom/add-serialization" } +halo2_proofs = { git = "https://github.com/axiom-crypto/halo2.git", branch = "axiom/add-serialization" } \ No newline at end of file diff --git a/snark-verifier-sdk/Cargo.toml b/snark-verifier-sdk/Cargo.toml new file mode 100644 index 00000000..fd5c2344 --- /dev/null +++ b/snark-verifier-sdk/Cargo.toml @@ -0,0 +1,66 @@ +[package] +name = "snark-verifier-sdk" +version = "0.0.1" +edition = "2021" + +[dependencies] +itertools = "0.10.3" +lazy_static = "1.4.0" +num-bigint = "0.4.3" +num-integer = "0.1.45" +num-traits = "0.2.15" +rand = "0.8" +rand_chacha = "0.3.1" +hex = "0.4" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +bincode = "1.3.3" +ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } + +halo2-base = { git = "ssh://github.com/axiom-crypto/axiom-core-working.git", branch = "experiment/optimizations", default-features = false } +snark-verifier = { path = "../snark-verifier", default-features = false } + +# loader_evm +ethereum-types = { version = "0.14", default-features = false, features = ["std"], optional = true } +# sha3 = { version = "0.10", optional = true } +# revm = { version = "2.3.1", optional = true } +# bytes = { version = "1.2", optional = true } +# rlp = { version = "0.5", default-features = false, features = ["std"], optional = true } + +# zkevm benchmarks +zkevm-circuits = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", features = ["test"], optional = true } +bus-mapping = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", optional = true } +eth-types = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", optional = true } +mock = { git = "https://github.com/jonathanpwang/zkevm-circuits.git", branch = "bench-12-04", optional = true } + +[dev-dependencies] +ark-std = { version = "0.3.0", features = ["print-trace"] } +paste = "1.0.7" +pprof = { version = "0.11", features = ["criterion", "flamegraph"] } +criterion = "0.4" +criterion-macro = "0.4" +# loader_evm +crossterm = { version = "0.25" } +tui = { version = "0.19", default-features = false, features = ["crossterm"] } + +[features] +default = ["loader_evm", "loader_halo2", "zkevm", "halo2-pse", "halo2-base/jemallocator"] +display = ["snark-verifier/display"] +loader_evm = ["snark-verifier/loader_evm", "dep:ethereum-types"] +loader_halo2 = ["snark-verifier/loader_halo2"] +parallel = ["snark-verifier/parallel"] +# EXACTLY one of halo2-pse / halo2-axiom should always be turned on; not sure how to enforce this with Cargo +halo2-pse = ["snark-verifier/halo2-pse"] +halo2-axiom = ["snark-verifier/halo2-axiom"] + +zkevm = ["dep:zkevm-circuits", "dep:bus-mapping", "dep:mock", "dep:eth-types"] + +[[bench]] +name = "standard_plonk" +required-features = ["loader_halo2"] +harness = false + +[[bench]] +name = "zkevm" +required-features = ["loader_halo2", "zkevm", "halo2-pse", "halo2-base/jemallocator"] +harness = false \ No newline at end of file diff --git a/benches/standard_plonk.rs b/snark-verifier-sdk/benches/standard_plonk.rs similarity index 88% rename from benches/standard_plonk.rs rename to snark-verifier-sdk/benches/standard_plonk.rs index 189d6bdf..c72072eb 100644 --- a/benches/standard_plonk.rs +++ b/snark-verifier-sdk/benches/standard_plonk.rs @@ -9,19 +9,18 @@ use halo2_proofs::{ halo2curves::bn256::Bn256, poly::{commitment::Params, kzg::commitment::ParamsKZG}, }; -use plonk_verifier::{ - loader::native::NativeLoader, - sdk::{ - self, gen_pk, - halo2::{ - aggregation::AggregationCircuit, gen_proof_shplonk, gen_snark_shplonk, - PoseidonTranscript, POSEIDON_SPEC, - }, - }, -}; use rand::rngs::OsRng; use rand::SeedableRng; use rand_chacha::ChaCha20Rng; +use snark_verifier::loader::native::NativeLoader; +use snark_verifier_sdk::{ + gen_pk, + halo2::{ + aggregation::AggregationCircuit, gen_proof_shplonk, gen_snark_shplonk, PoseidonTranscript, + POSEIDON_SPEC, + }, + Snark, +}; mod application { use super::halo2_curves::bn256::Fr; @@ -30,9 +29,8 @@ mod application { plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, poly::Rotation, }; - use halo2_base::halo2_proofs::plonk::Assigned; - use plonk_verifier::sdk::CircuitExt; use rand::RngCore; + use snark_verifier_sdk::CircuitExt; #[derive(Clone, Copy)] pub struct StandardPlonkConfig { @@ -157,13 +155,9 @@ mod application { 0, Value::known(Assigned::Trivial(self.0)), )?; - region.assign_fixed(config.q_a, 0, Assigned::Trivial(-Fr::one())); + region.assign_fixed(config.q_a, 0, -Fr::one()); - region.assign_advice( - config.a, - 1, - Value::known(Assigned::Trivial(-Fr::from(5u64))), - )?; + region.assign_advice(config.a, 1, Value::known(-Fr::from(5u64)))?; for (idx, column) in (1..).zip([ config.q_a, config.q_b, @@ -171,14 +165,10 @@ mod application { config.q_ab, config.constant, ]) { - region.assign_fixed(column, 1, Assigned::Trivial(Fr::from(idx as u64))); + region.assign_fixed(column, 1, Fr::from(idx as u64)); } - let a = region.assign_advice( - config.a, - 2, - Value::known(Assigned::Trivial(Fr::one())), - )?; + let a = region.assign_advice(config.a, 2, Value::known(Fr::one()))?; a.copy_advice(&mut region, config.b, 3); a.copy_advice(&mut region, config.c, 4); } @@ -193,7 +183,7 @@ mod application { fn gen_application_snark( params: &ParamsKZG, transcript: &mut PoseidonTranscript>, -) -> sdk::Snark { +) -> Snark { let circuit = application::StandardPlonk::rand(OsRng); let pk = gen_pk(params, &circuit, None); diff --git a/benches/zkevm.rs b/snark-verifier-sdk/benches/zkevm.rs similarity index 88% rename from benches/zkevm.rs rename to snark-verifier-sdk/benches/zkevm.rs index 8fb4ab26..a35573b1 100644 --- a/benches/zkevm.rs +++ b/snark-verifier-sdk/benches/zkevm.rs @@ -1,28 +1,24 @@ -use std::env::{set_var, var}; -use std::path::Path; - use ark_std::{end_timer, start_timer}; use halo2_base::halo2_proofs; use halo2_base::utils::fs::gen_srs; use halo2_proofs::halo2curves::bn256::Fr; -use plonk_verifier::sdk::halo2::aggregation::load_verify_circuit_degree; -use plonk_verifier::{ - loader::native::NativeLoader, - sdk::{ - self, - evm::{ - evm_verify, gen_evm_proof_gwc, gen_evm_proof_shplonk, gen_evm_verifier_gwc, - gen_evm_verifier_shplonk, - }, - gen_pk, - halo2::{ - aggregation::AggregationCircuit, gen_proof_gwc, gen_proof_shplonk, gen_snark_shplonk, - PoseidonTranscript, POSEIDON_SPEC, - }, - }, -}; use rand::SeedableRng; use rand_chacha::ChaCha20Rng; +use snark_verifier::loader::native::NativeLoader; +use snark_verifier_sdk::{ + self, + evm::{ + evm_verify, gen_evm_proof_gwc, gen_evm_proof_shplonk, gen_evm_verifier_gwc, + gen_evm_verifier_shplonk, + }, + gen_pk, + halo2::{ + aggregation::load_verify_circuit_degree, aggregation::AggregationCircuit, gen_proof_gwc, + gen_proof_shplonk, gen_snark_shplonk, PoseidonTranscript, POSEIDON_SPEC, + }, +}; +use std::env::{set_var, var}; +use std::path::Path; use criterion::{criterion_group, criterion_main}; use criterion::{BenchmarkId, Criterion}; diff --git a/configs/bench_zkevm.config b/snark-verifier-sdk/configs/bench_zkevm.config similarity index 100% rename from configs/bench_zkevm.config rename to snark-verifier-sdk/configs/bench_zkevm.config diff --git a/snark-verifier-sdk/configs/example_evm_accumulator.config b/snark-verifier-sdk/configs/example_evm_accumulator.config new file mode 100644 index 00000000..fcda49a0 --- /dev/null +++ b/snark-verifier-sdk/configs/example_evm_accumulator.config @@ -0,0 +1 @@ +{"strategy":"Simple","degree":21,"num_advice":5,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/configs/verify_circuit.config b/snark-verifier-sdk/configs/verify_circuit.config similarity index 100% rename from configs/verify_circuit.config rename to snark-verifier-sdk/configs/verify_circuit.config diff --git a/src/sdk/evm.rs b/snark-verifier-sdk/src/evm.rs similarity index 87% rename from src/sdk/evm.rs rename to snark-verifier-sdk/src/evm.rs index e6154ee2..f9d6215e 100644 --- a/src/sdk/evm.rs +++ b/snark-verifier-sdk/src/evm.rs @@ -1,36 +1,32 @@ -use super::CircuitExt; -use crate::{ - halo2_proofs::{ - dev::MockProver, - halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, - plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, - poly::{ - commitment::{Params, ParamsProver, Prover, Verifier}, - kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - msm::DualMSM, - multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}, - strategy::{AccumulatorStrategy, GuardKZG}, - }, - VerificationStrategy, +use super::{CircuitExt, Plonk}; +use ethereum_types::Address; +use halo2_base::halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, + plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, + poly::{ + commitment::{Params, ParamsProver, Prover, Verifier}, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + msm::DualMSM, + multiopen::{ProverGWC, ProverSHPLONK, VerifierGWC, VerifierSHPLONK}, + strategy::{AccumulatorStrategy, GuardKZG}, }, - transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, + VerificationStrategy, }, - pcs::kzg::{Bdfg21, Gwc19, Kzg}, - sdk::Plonk, + transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, }; -use crate::{ +use itertools::Itertools; +use rand::Rng; +use snark_verifier::{ loader::evm::{compile_yul, encode_calldata, EvmLoader, ExecutorBuilder}, pcs::{ - kzg::{KzgAccumulator, KzgDecidingKey, KzgSuccinctVerifyingKey}, + kzg::{Bdfg21, Gwc19, Kzg, KzgAccumulator, KzgDecidingKey, KzgSuccinctVerifyingKey}, Decider, MultiOpenScheme, PolynomialCommitmentScheme, }, system::halo2::{compile, transcript::evm::EvmTranscript, Config}, verifier::PlonkVerifier, }; -use ethereum_types::Address; -use itertools::Itertools; -use rand::Rng; use std::{fs, io, path::Path, rc::Rc}; /// Generates a proof for evm verification using either SHPLONK or GWC proving method. Uses Keccak for Fiat-Shamir. diff --git a/src/sdk/halo2.rs b/snark-verifier-sdk/src/halo2.rs similarity index 97% rename from src/sdk/halo2.rs rename to snark-verifier-sdk/src/halo2.rs index 40a65ad7..b453b67e 100644 --- a/src/sdk/halo2.rs +++ b/snark-verifier-sdk/src/halo2.rs @@ -1,17 +1,7 @@ use super::{read_instances, write_instances, CircuitExt, Snark, SnarkWitness}; -use crate::cost::CostEstimation; -use crate::halo2_proofs; -use crate::pcs::MultiOpenScheme; -use crate::{ - loader::native::NativeLoader, - pcs, - poseidon::Spec, - system::halo2::{compile, Config}, - util::transcript::TranscriptWrite, - verifier::PlonkProof, -}; #[cfg(feature = "display")] use ark_std::{end_timer, start_timer}; +use halo2_base::{halo2_proofs, poseidon::Spec}; use halo2_proofs::{ circuit::Layouter, dev::MockProver, @@ -37,6 +27,14 @@ use halo2_proofs::{ use itertools::Itertools; use lazy_static::lazy_static; use rand::Rng; +use snark_verifier::{ + cost::CostEstimation, + loader::native::NativeLoader, + pcs::{self, MultiOpenScheme}, + system::halo2::{compile, Config}, + util::transcript::TranscriptWrite, + verifier::PlonkProof, +}; use std::{fs, iter, marker::PhantomData, path::Path}; pub mod aggregation; @@ -48,7 +46,15 @@ const R_F: usize = 8; const R_P: usize = 60; pub type PoseidonTranscript = - crate::system::halo2::transcript::halo2::PoseidonTranscript; + snark_verifier::system::halo2::transcript::halo2::PoseidonTranscript< + G1Affine, + L, + S, + T, + RATE, + R_F, + R_P, + >; lazy_static! { pub static ref POSEIDON_SPEC: Spec = Spec::new(R_F, R_P); diff --git a/src/sdk/halo2/aggregation.rs b/snark-verifier-sdk/src/halo2/aggregation.rs similarity index 95% rename from src/sdk/halo2/aggregation.rs rename to snark-verifier-sdk/src/halo2/aggregation.rs index d5400a41..81cfd12e 100644 --- a/src/sdk/halo2/aggregation.rs +++ b/snark-verifier-sdk/src/halo2/aggregation.rs @@ -1,25 +1,28 @@ -use crate::halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; -use crate::halo2_proofs::{ +use crate::{Plonk, BITS, LIMBS}; +#[cfg(feature = "display")] +use ark_std::{end_timer, start_timer}; +use halo2_base::halo2_proofs::{ circuit::{Layouter, SimpleFloorPlanner, Value}, + halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, plonk::{self, Circuit, Column, ConstraintSystem, Instance, Selector}, poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, }; -use crate::{ - loader::{self, native::NativeLoader}, +use halo2_base::{Context, ContextParams}; +use itertools::Itertools; +use rand::Rng; +use snark_verifier::{ + loader::{ + self, + halo2::halo2_ecc::{self, ecc::EccChip}, + native::NativeLoader, + }, pcs::{ kzg::{Bdfg21, Kzg, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey}, AccumulationScheme, AccumulationSchemeProver, MultiOpenScheme, PolynomialCommitmentScheme, }, - sdk::{Plonk, BITS, LIMBS}, util::arithmetic::fe_to_limbs, verifier::PlonkVerifier, }; -#[cfg(feature = "display")] -use ark_std::{end_timer, start_timer}; -use halo2_base::{Context, ContextParams}; -use halo2_ecc::ecc::EccChip; -use itertools::Itertools; -use rand::Rng; use std::{fs::File, rc::Rc}; use super::{CircuitExt, PoseidonTranscript, Snark, SnarkWitness, POSEIDON_SPEC}; @@ -167,6 +170,8 @@ impl AggregationConfig { } /// Aggregation circuit that does not re-expose any public inputs from aggregated snarks +/// +/// This is mostly a reference implementation. In practice one will probably need to re-implement the circuit for one's particular use case with specific instance logic. #[derive(Clone)] pub struct AggregationCircuit { svk: Svk, @@ -260,7 +265,7 @@ impl CircuitExt for AggregationCircuit { } fn selectors(config: &Self::Config) -> Vec { - config.gate().basic_gates.iter().map(|gate| gate.q_enable).collect() + config.gate().basic_gates[0].iter().map(|gate| gate.q_enable).collect() } } @@ -312,7 +317,7 @@ impl Circuit for AggregationCircuit { region, ContextParams { max_rows: config.gate().max_rows, - num_advice: vec![config.gate().num_advice], + num_context_ids: 1, fixed_columns: config.gate().constants.clone(), }, ); diff --git a/src/sdk.rs b/snark-verifier-sdk/src/lib.rs similarity index 97% rename from src/sdk.rs rename to snark-verifier-sdk/src/lib.rs index a2585f79..e46b704d 100644 --- a/src/sdk.rs +++ b/snark-verifier-sdk/src/lib.rs @@ -1,8 +1,6 @@ -#![allow(clippy::let_and_return)] -use crate::halo2_proofs; -use crate::{pcs::kzg::LimbsEncoding, verifier, Protocol}; #[cfg(feature = "display")] use ark_std::{end_timer, start_timer}; +use halo2_base::halo2_proofs; use halo2_proofs::{ circuit::Value, halo2curves::{ @@ -13,6 +11,7 @@ use halo2_proofs::{ poly::kzg::commitment::ParamsKZG, }; use itertools::Itertools; +use snark_verifier::{pcs::kzg::LimbsEncoding, verifier, Protocol}; use std::{ fs::{self, File}, io::{BufReader, BufWriter}, diff --git a/snark-verifier/Cargo.toml b/snark-verifier/Cargo.toml new file mode 100644 index 00000000..146c69fb --- /dev/null +++ b/snark-verifier/Cargo.toml @@ -0,0 +1,61 @@ +[package] +name = "snark-verifier" +version = "0.1.0" +edition = "2021" + +[dependencies] +itertools = "0.10.3" +lazy_static = "1.4.0" +num-bigint = "0.4.3" +num-integer = "0.1.45" +num-traits = "0.2.15" +hex = "0.4" +rand = "0.8" + +# Use halo2-base as non-optional dependency because it re-exports halo2_proofs, halo2curves, and poseidon, using different repos based on feature flag "halo2-axiom" or "halo2-pse" +halo2-base = { git = "ssh://github.com/axiom-crypto/axiom-core-working.git", branch = "experiment/optimizations", default-features = false } + +# parallel +rayon = { version = "1.5.3", optional = true } + +# loader_evm +ethereum-types = { version = "0.14", default-features = false, features = ["std"], optional = true } +sha3 = { version = "0.10", optional = true } +revm = { version = "2.3.1", optional = true } +bytes = { version = "1.2", optional = true } +rlp = { version = "0.5", default-features = false, features = ["std"], optional = true } + +# loader_halo2 +halo2-ecc = { git = "ssh://github.com/axiom-crypto/axiom-core-working.git", branch = "experiment/optimizations", default-features = false, optional = true } + +[dev-dependencies] +ark-std = { version = "0.3.0", features = ["print-trace"] } +paste = "1.0.7" +rand_chacha = "0.3.1" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +# loader_evm +crossterm = { version = "0.25" } +tui = { version = "0.19", default-features = false, features = ["crossterm"] } + +[features] +default = ["loader_evm", "loader_halo2", "halo2-pse"] +display = ["halo2-base/display", "halo2-ecc?/display"] +loader_evm = ["dep:ethereum-types", "dep:sha3", "dep:revm", "dep:bytes", "dep:rlp"] +loader_halo2 = ["halo2-ecc"] +parallel = ["dep:rayon"] +# EXACTLY one of halo2-pse / halo2-axiom should always be turned on; not sure how to enforce this with Cargo +halo2-pse = ["halo2-base/halo2-pse", "halo2-ecc?/halo2-pse"] +halo2-axiom = ["halo2-base/halo2-axiom", "halo2-ecc?/halo2-axiom"] + +[[example]] +name = "evm-verifier" +required-features = ["loader_evm"] + +[[example]] +name = "evm-verifier-with-accumulator" +required-features = ["loader_halo2", "loader_evm"] + +[[example]] +name = "recursion" +required-features = ["loader_halo2"] \ No newline at end of file diff --git a/snark-verifier/configs/example_evm_accumulator.config b/snark-verifier/configs/example_evm_accumulator.config new file mode 100644 index 00000000..fcda49a0 --- /dev/null +++ b/snark-verifier/configs/example_evm_accumulator.config @@ -0,0 +1 @@ +{"strategy":"Simple","degree":21,"num_advice":5,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":20,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/configs/example_evm_accumulator.config b/snark-verifier/configs/verify_circuit.config similarity index 100% rename from configs/example_evm_accumulator.config rename to snark-verifier/configs/verify_circuit.config diff --git a/examples/README.md b/snark-verifier/examples/README.md similarity index 77% rename from examples/README.md rename to snark-verifier/examples/README.md index 443dc67e..67cd4267 100644 --- a/examples/README.md +++ b/snark-verifier/examples/README.md @@ -1,4 +1,4 @@ -In `plonk-verifier` root directory: +In `snark-verifier` root directory: 1. Create `./configs/verify_circuit.config` diff --git a/examples/evm-verifier-with-accumulator.rs b/snark-verifier/examples/evm-verifier-with-accumulator.rs similarity index 89% rename from examples/evm-verifier-with-accumulator.rs rename to snark-verifier/examples/evm-verifier-with-accumulator.rs index 84ed5003..cdb5bb16 100644 --- a/examples/evm-verifier-with-accumulator.rs +++ b/snark-verifier/examples/evm-verifier-with-accumulator.rs @@ -1,8 +1,5 @@ -use ark_std::{end_timer, start_timer}; use ethereum_types::Address; use halo2_base::halo2_proofs; -use halo2_base::halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK}; -use halo2_proofs::halo2curves as halo2_curves; use halo2_proofs::{ dev::MockProver, halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, @@ -19,39 +16,32 @@ use halo2_proofs::{ transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer}, }; use itertools::Itertools; -use plonk_verifier::pcs::kzg::Bdfg21; -use plonk_verifier::sdk::evm::{ - evm_verify, gen_evm_proof_gwc, gen_evm_proof_shplonk, gen_evm_verifier_gwc, - gen_evm_verifier_shplonk, -}; -use plonk_verifier::{ +use rand::rngs::OsRng; +use snark_verifier::{ loader::{ evm::{self, encode_calldata, EvmLoader, ExecutorBuilder}, native::NativeLoader, }, pcs::kzg::{Gwc19, Kzg, KzgAs, LimbsEncoding}, - sdk::CircuitExt, system::halo2::{compile, transcript::evm::EvmTranscript, Config}, verifier::{self, PlonkVerifier}, }; -use rand::rngs::OsRng; -use std::path::Path; use std::{io::Cursor, rc::Rc}; const LIMBS: usize = 3; const BITS: usize = 88; -type Pcs = Kzg; +type Pcs = Kzg; type As = KzgAs; type Plonk = verifier::Plonk>; mod application { - use super::halo2_curves::bn256::Fr; use super::halo2_proofs::{ circuit::{Layouter, SimpleFloorPlanner, Value}, plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Fixed, Instance}, poly::Rotation, }; + use super::Fr; use halo2_base::halo2_proofs::plonk::Assigned; use rand::RngCore; @@ -210,29 +200,19 @@ mod application { } mod aggregation { - use super::halo2_curves::bn256::{Bn256, Fq, Fr, G1Affine}; use super::halo2_proofs::{ circuit::{Cell, Layouter, SimpleFloorPlanner, Value}, plonk::{self, Circuit, Column, ConstraintSystem, Instance, Selector}, poly::{commitment::ParamsProver, kzg::commitment::ParamsKZG}, }; use super::{As, Plonk, BITS, LIMBS}; + use super::{Bn256, Fq, Fr, G1Affine}; use ark_std::{end_timer, start_timer}; use halo2_base::{Context, ContextParams}; use halo2_ecc::ecc::EccChip; - /* - use halo2_wrong_ecc::{ - integer::rns::Rns, - maingate::{ - MainGate, MainGateConfig, MainGateInstructions, RangeChip, RangeConfig, - RangeInstructions, RegionCtx, - }, - EccConfig, - }; - */ use itertools::Itertools; - use plonk_verifier::sdk::CircuitExt; - use plonk_verifier::{ + use rand::rngs::OsRng; + use snark_verifier::{ loader::{self, native::NativeLoader}, pcs::{ kzg::{KzgAccumulator, KzgSuccinctVerifyingKey}, @@ -243,7 +223,6 @@ mod aggregation { verifier::PlonkVerifier, Protocol, }; - use rand::rngs::OsRng; use std::{fs::File, rc::Rc}; const T: usize = 5; @@ -254,7 +233,6 @@ mod aggregation { type Svk = KzgSuccinctVerifyingKey; type BaseFieldEccChip = halo2_ecc::ecc::BaseFieldEccChip; type Halo2Loader<'a> = loader::halo2::Halo2Loader<'a, G1Affine, BaseFieldEccChip>; - // type BaseFieldEccChip = halo2_wrong_ecc::BaseFieldEccChip; pub type PoseidonTranscript = system::halo2::transcript::halo2::PoseidonTranscript; @@ -448,33 +426,21 @@ mod aggregation { pub fn as_proof(&self) -> Value<&[u8]> { self.as_proof.as_ref().map(Vec::as_slice) } - } - impl CircuitExt for AggregationCircuit { - fn num_instance() -> Vec { + pub fn num_instance() -> Vec { // [..lhs, ..rhs] vec![4 * LIMBS] } - fn instances(&self) -> Vec> { + pub fn instances(&self) -> Vec> { vec![self.instances.clone()] } - fn accumulator_indices() -> Option> { - Some((0..4 * LIMBS).map(|idx| (0, idx)).collect()) - } - - fn selectors(config: &Self::Config) -> Vec { - config - .base_field_config - .range - .gate - .basic_gates - .iter() - .map(|gate| gate.q_enable) - .collect() + pub fn accumulator_indices() -> Vec<(usize, usize)> { + (0..4 * LIMBS).map(|idx| (0, idx)).collect() } } + impl Circuit for AggregationCircuit { type Config = AggregationConfig; type FloorPlanner = SimpleFloorPlanner; @@ -489,10 +455,10 @@ mod aggregation { } fn configure(meta: &mut plonk::ConstraintSystem) -> Self::Config { - let path = std::env::var("VERIFY_CONFIG") - .unwrap_or_else(|_| "configs/verify_circuit.config".to_owned()); + let path = std::env::var("VERIFY_CONFIG").unwrap(); let params: AggregationConfigParams = serde_json::from_reader( - File::open(path.as_str()).unwrap_or_else(|err| panic!("{err:?}")), + File::open(path.as_str()) + .unwrap_or_else(|err| panic!("Path {path} does not exist: {err:?}")), ) .unwrap(); @@ -521,7 +487,7 @@ mod aggregation { region, ContextParams { max_rows, - num_advice: vec![config.base_field_config.range.gate.num_advice], + num_context_ids: 1, fixed_columns: config.base_field_config.range.gate.constants.clone(), }, ); @@ -538,7 +504,7 @@ mod aggregation { #[cfg(feature = "display")] println!("Total advice cells: {}", loader.ctx().total_advice); #[cfg(feature = "display")] - println!("Advice columns used: {}", loader.ctx().advice_alloc[0][0].0 + 1); + println!("Advice columns used: {}", loader.ctx().advice_alloc[0].0 + 1); let instances: Vec<_> = lhs .x @@ -586,11 +552,9 @@ fn gen_proof< MockProver::run(params.k(), &circuit, instances.clone()).unwrap().assert_satisfied(); let instances = instances.iter().map(|instances| instances.as_slice()).collect_vec(); - - let proof_time = start_timer!(|| "Create proof"); let proof = { let mut transcript = TW::init(Vec::new()); - create_proof::, ProverSHPLONK<_>, _, _, TW, _>( + create_proof::, ProverGWC<_>, _, _, TW, _>( params, pk, &[circuit], @@ -601,12 +565,11 @@ fn gen_proof< .unwrap(); transcript.finalize() }; - end_timer!(proof_time); let accept = { let mut transcript = TR::init(Cursor::new(proof.clone())); - VerificationStrategy::<_, VerifierSHPLONK<_>>::finalize( - verify_proof::<_, VerifierSHPLONK<_>, _, TR, _>( + VerificationStrategy::<_, VerifierGWC<_>>::finalize( + verify_proof::<_, VerifierGWC<_>, _, TR, _>( params.verifier_params(), pk.get_vk(), AccumulatorStrategy::new(params.verifier_params()), @@ -640,6 +603,49 @@ fn gen_application_snark(params: &ParamsKZG) -> aggregation::Snark { aggregation::Snark::new(protocol, circuit.instances(), proof) } +fn gen_aggregation_evm_verifier( + params: &ParamsKZG, + vk: &VerifyingKey, + num_instance: Vec, + accumulator_indices: Vec<(usize, usize)>, +) -> Vec { + let svk = params.get_g()[0].into(); + let dk = (params.g2(), params.s_g2()).into(); + let protocol = compile( + params, + vk, + Config::kzg() + .with_num_instance(num_instance.clone()) + .with_accumulator_indices(Some(accumulator_indices)), + ); + + let loader = EvmLoader::new::(); + let protocol = protocol.loaded(&loader); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); + + let instances = transcript.load_instances(num_instance); + let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); + Plonk::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); + + evm::compile_yul(&loader.yul_code()) +} + +fn evm_verify(deployment_code: Vec, instances: Vec>, proof: Vec) { + let calldata = encode_calldata(&instances, &proof); + let success = { + let mut evm = ExecutorBuilder::default().with_gas_limit(u64::MAX.into()).build(); + + let caller = Address::from_low_u64_be(0xfe); + let verifier = evm.deploy(caller, deployment_code.into(), 0.into()).address.unwrap(); + let result = evm.call_raw(caller, verifier, calldata.into(), 0.into()); + + dbg!(result.gas_used); + + !result.reverted + }; + assert!(success); +} + fn main() { std::env::set_var("VERIFY_CONFIG", "./configs/example_evm_accumulator.config"); let params = halo2_base::utils::fs::gen_srs(21); @@ -652,18 +658,18 @@ fn main() { let snarks = [(); 3].map(|_| gen_application_snark(¶ms_app)); let agg_circuit = aggregation::AggregationCircuit::new(¶ms, snarks); let pk = gen_pk(¶ms, &agg_circuit); - let deployment_code = gen_evm_verifier_shplonk::( + let deployment_code = gen_aggregation_evm_verifier( ¶ms, pk.get_vk(), - Some(Path::new("evm_verifier.yul")), + aggregation::AggregationCircuit::num_instance(), + aggregation::AggregationCircuit::accumulator_indices(), ); - let proof = gen_evm_proof_shplonk( + let proof = gen_proof::<_, _, EvmTranscript, EvmTranscript>( ¶ms, &pk, agg_circuit.clone(), agg_circuit.instances(), - &mut OsRng, ); evm_verify(deployment_code, agg_circuit.instances(), proof); } diff --git a/examples/evm-verifier.rs b/snark-verifier/examples/evm-verifier.rs similarity index 99% rename from examples/evm-verifier.rs rename to snark-verifier/examples/evm-verifier.rs index 9f66ed86..d7a1f0c8 100644 --- a/examples/evm-verifier.rs +++ b/snark-verifier/examples/evm-verifier.rs @@ -20,13 +20,13 @@ use halo2_proofs::{ transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, }; use itertools::Itertools; -use plonk_verifier::{ +use rand::{rngs::OsRng, RngCore}; +use snark_verifier::{ loader::evm::{self, encode_calldata, EvmLoader, ExecutorBuilder}, pcs::kzg::{Gwc19, Kzg}, system::halo2::{compile, transcript::evm::EvmTranscript, Config}, verifier::{self, PlonkVerifier}, }; -use rand::{rngs::OsRng, RngCore}; use std::rc::Rc; type Plonk = verifier::Plonk>; diff --git a/examples/recursion.rs b/snark-verifier/examples/recursion.rs similarity index 98% rename from examples/recursion.rs rename to snark-verifier/examples/recursion.rs index 86c0d415..569d6a12 100644 --- a/examples/recursion.rs +++ b/snark-verifier/examples/recursion.rs @@ -10,7 +10,7 @@ use halo2_proofs::{ halo2curves::{ bn256::{Bn256, Fq, Fr, G1Affine}, group::ff::Field, - CurveAffine, FieldExt, + FieldExt, }, plonk::{ self, create_proof, keygen_pk, keygen_vk, Circuit, ConstraintSystem, Error, ProvingKey, @@ -19,7 +19,7 @@ use halo2_proofs::{ poly::{ commitment::ParamsProver, kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, + commitment::ParamsKZG, multiopen::{ProverGWC, VerifierGWC}, strategy::AccumulatorStrategy, }, @@ -27,7 +27,8 @@ use halo2_proofs::{ }, }; use itertools::Itertools; -use plonk_verifier::{ +use rand_chacha::rand_core::OsRng; +use snark_verifier::{ loader::{self, native::NativeLoader, Loader, ScalarLoader}, pcs::{ kzg::{Gwc19, Kzg, KzgAccumulator, KzgAs, KzgSuccinctVerifyingKey, LimbsEncoding}, @@ -41,10 +42,6 @@ use plonk_verifier::{ verifier::{self, PlonkProof, PlonkVerifier}, Protocol, }; -use rand_chacha::{ - rand_core::{OsRng, SeedableRng}, - ChaCha20Rng, -}; use std::{fs, iter, marker::PhantomData, rc::Rc}; use crate::recursion::AggregationConfigParams; @@ -67,7 +64,7 @@ type PoseidonTranscript = mod common { use super::*; use halo2_proofs::{plonk::verify_proof, poly::commitment::Params}; - use plonk_verifier::{cost::CostEstimation, util::transcript::TranscriptWrite}; + use snark_verifier::{cost::CostEstimation, util::transcript::TranscriptWrite}; pub fn poseidon>( loader: &L, @@ -354,7 +351,7 @@ mod recursion { }; use halo2_ecc::ecc::EccChip; use halo2_proofs::plonk::{Column, Instance}; - use plonk_verifier::loader::halo2::{EccInstructions, IntegerInstructions}; + use snark_verifier::loader::halo2::{EccInstructions, IntegerInstructions}; use super::*; @@ -663,7 +660,7 @@ mod recursion { region, ContextParams { max_rows, - num_advice: vec![config.base_field_config.range.gate.num_advice], + num_context_ids: 1, fixed_columns: config.base_field_config.range.gate.constants.clone(), }, ); @@ -813,11 +810,7 @@ mod recursion { } fn selectors(config: &Self::Config) -> Vec { - config - .base_field_config - .range - .gate - .basic_gates + config.base_field_config.range.gate.basic_gates[0] .iter() .map(|gate| gate.q_enable) .collect() diff --git a/src/cost.rs b/snark-verifier/src/cost.rs similarity index 100% rename from src/cost.rs rename to snark-verifier/src/cost.rs diff --git a/src/lib.rs b/snark-verifier/src/lib.rs similarity index 99% rename from src/lib.rs rename to snark-verifier/src/lib.rs index 3878c019..a6e4b3e6 100644 --- a/src/lib.rs +++ b/snark-verifier/src/lib.rs @@ -5,7 +5,6 @@ pub mod cost; pub mod loader; pub mod pcs; -pub mod sdk; pub mod system; pub mod util; pub mod verifier; diff --git a/src/loader.rs b/snark-verifier/src/loader.rs similarity index 100% rename from src/loader.rs rename to snark-verifier/src/loader.rs diff --git a/src/loader/evm.rs b/snark-verifier/src/loader/evm.rs similarity index 100% rename from src/loader/evm.rs rename to snark-verifier/src/loader/evm.rs diff --git a/src/loader/evm/code.rs b/snark-verifier/src/loader/evm/code.rs similarity index 100% rename from src/loader/evm/code.rs rename to snark-verifier/src/loader/evm/code.rs diff --git a/src/loader/evm/loader.rs b/snark-verifier/src/loader/evm/loader.rs similarity index 100% rename from src/loader/evm/loader.rs rename to snark-verifier/src/loader/evm/loader.rs diff --git a/src/loader/evm/test.rs b/snark-verifier/src/loader/evm/test.rs similarity index 100% rename from src/loader/evm/test.rs rename to snark-verifier/src/loader/evm/test.rs diff --git a/src/loader/evm/test/tui.rs b/snark-verifier/src/loader/evm/test/tui.rs similarity index 100% rename from src/loader/evm/test/tui.rs rename to snark-verifier/src/loader/evm/test/tui.rs diff --git a/src/loader/evm/util.rs b/snark-verifier/src/loader/evm/util.rs similarity index 100% rename from src/loader/evm/util.rs rename to snark-verifier/src/loader/evm/util.rs diff --git a/src/loader/evm/util/executor.rs b/snark-verifier/src/loader/evm/util/executor.rs similarity index 100% rename from src/loader/evm/util/executor.rs rename to snark-verifier/src/loader/evm/util/executor.rs diff --git a/src/loader/halo2.rs b/snark-verifier/src/loader/halo2.rs similarity index 98% rename from src/loader/halo2.rs rename to snark-verifier/src/loader/halo2.rs index 40512fe7..0e84d506 100644 --- a/src/loader/halo2.rs +++ b/snark-verifier/src/loader/halo2.rs @@ -12,7 +12,7 @@ pub use loader::{EcPoint, Halo2Loader, Scalar}; pub use shim::{Context, EccInstructions, IntegerInstructions}; pub use util::Valuetools; -// pub use halo2_wrong_ecc; +pub use halo2_ecc; mod util { use crate::halo2_proofs::circuit::Value; diff --git a/src/loader/halo2/loader.rs b/snark-verifier/src/loader/halo2/loader.rs similarity index 100% rename from src/loader/halo2/loader.rs rename to snark-verifier/src/loader/halo2/loader.rs diff --git a/src/loader/halo2/shim.rs b/snark-verifier/src/loader/halo2/shim.rs similarity index 100% rename from src/loader/halo2/shim.rs rename to snark-verifier/src/loader/halo2/shim.rs diff --git a/src/loader/halo2/test.rs b/snark-verifier/src/loader/halo2/test.rs similarity index 100% rename from src/loader/halo2/test.rs rename to snark-verifier/src/loader/halo2/test.rs diff --git a/src/loader/native.rs b/snark-verifier/src/loader/native.rs similarity index 100% rename from src/loader/native.rs rename to snark-verifier/src/loader/native.rs diff --git a/src/pcs.rs b/snark-verifier/src/pcs.rs similarity index 100% rename from src/pcs.rs rename to snark-verifier/src/pcs.rs diff --git a/src/pcs/ipa.rs b/snark-verifier/src/pcs/ipa.rs similarity index 100% rename from src/pcs/ipa.rs rename to snark-verifier/src/pcs/ipa.rs diff --git a/src/pcs/ipa/accumulation.rs b/snark-verifier/src/pcs/ipa/accumulation.rs similarity index 100% rename from src/pcs/ipa/accumulation.rs rename to snark-verifier/src/pcs/ipa/accumulation.rs diff --git a/src/pcs/ipa/accumulator.rs b/snark-verifier/src/pcs/ipa/accumulator.rs similarity index 100% rename from src/pcs/ipa/accumulator.rs rename to snark-verifier/src/pcs/ipa/accumulator.rs diff --git a/src/pcs/ipa/decider.rs b/snark-verifier/src/pcs/ipa/decider.rs similarity index 100% rename from src/pcs/ipa/decider.rs rename to snark-verifier/src/pcs/ipa/decider.rs diff --git a/src/pcs/ipa/multiopen.rs b/snark-verifier/src/pcs/ipa/multiopen.rs similarity index 100% rename from src/pcs/ipa/multiopen.rs rename to snark-verifier/src/pcs/ipa/multiopen.rs diff --git a/src/pcs/ipa/multiopen/bgh19.rs b/snark-verifier/src/pcs/ipa/multiopen/bgh19.rs similarity index 100% rename from src/pcs/ipa/multiopen/bgh19.rs rename to snark-verifier/src/pcs/ipa/multiopen/bgh19.rs diff --git a/src/pcs/kzg.rs b/snark-verifier/src/pcs/kzg.rs similarity index 100% rename from src/pcs/kzg.rs rename to snark-verifier/src/pcs/kzg.rs diff --git a/src/pcs/kzg/accumulation.rs b/snark-verifier/src/pcs/kzg/accumulation.rs similarity index 100% rename from src/pcs/kzg/accumulation.rs rename to snark-verifier/src/pcs/kzg/accumulation.rs diff --git a/src/pcs/kzg/accumulator.rs b/snark-verifier/src/pcs/kzg/accumulator.rs similarity index 100% rename from src/pcs/kzg/accumulator.rs rename to snark-verifier/src/pcs/kzg/accumulator.rs diff --git a/src/pcs/kzg/decider.rs b/snark-verifier/src/pcs/kzg/decider.rs similarity index 100% rename from src/pcs/kzg/decider.rs rename to snark-verifier/src/pcs/kzg/decider.rs diff --git a/src/pcs/kzg/multiopen.rs b/snark-verifier/src/pcs/kzg/multiopen.rs similarity index 100% rename from src/pcs/kzg/multiopen.rs rename to snark-verifier/src/pcs/kzg/multiopen.rs diff --git a/src/pcs/kzg/multiopen/bdfg21.rs b/snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs similarity index 100% rename from src/pcs/kzg/multiopen/bdfg21.rs rename to snark-verifier/src/pcs/kzg/multiopen/bdfg21.rs diff --git a/src/pcs/kzg/multiopen/gwc19.rs b/snark-verifier/src/pcs/kzg/multiopen/gwc19.rs similarity index 100% rename from src/pcs/kzg/multiopen/gwc19.rs rename to snark-verifier/src/pcs/kzg/multiopen/gwc19.rs diff --git a/src/system.rs b/snark-verifier/src/system.rs similarity index 100% rename from src/system.rs rename to snark-verifier/src/system.rs diff --git a/src/system/halo2.rs b/snark-verifier/src/system/halo2.rs similarity index 99% rename from src/system/halo2.rs rename to snark-verifier/src/system/halo2.rs index 2c537f10..1ba7c1cc 100644 --- a/src/system/halo2.rs +++ b/snark-verifier/src/system/halo2.rs @@ -123,7 +123,7 @@ pub fn compile<'a, C: CurveAffine, P: Params<'a, C>>( let instance_committing_key = query_instance.then(|| { instance_committing_key( params, - polynomials.num_instance().into_iter().max().unwrap_or_default(), + Iterator::max(polynomials.num_instance().into_iter()).unwrap_or_default(), ) }); diff --git a/src/system/halo2/aggregation.rs b/snark-verifier/src/system/halo2/aggregation.rs similarity index 100% rename from src/system/halo2/aggregation.rs rename to snark-verifier/src/system/halo2/aggregation.rs diff --git a/src/system/halo2/strategy.rs b/snark-verifier/src/system/halo2/strategy.rs similarity index 100% rename from src/system/halo2/strategy.rs rename to snark-verifier/src/system/halo2/strategy.rs diff --git a/src/system/halo2/test.rs b/snark-verifier/src/system/halo2/test.rs similarity index 100% rename from src/system/halo2/test.rs rename to snark-verifier/src/system/halo2/test.rs diff --git a/src/system/halo2/test/circuit.rs b/snark-verifier/src/system/halo2/test/circuit.rs similarity index 100% rename from src/system/halo2/test/circuit.rs rename to snark-verifier/src/system/halo2/test/circuit.rs diff --git a/src/system/halo2/test/circuit/maingate.rs b/snark-verifier/src/system/halo2/test/circuit/maingate.rs similarity index 100% rename from src/system/halo2/test/circuit/maingate.rs rename to snark-verifier/src/system/halo2/test/circuit/maingate.rs diff --git a/src/system/halo2/test/circuit/standard.rs b/snark-verifier/src/system/halo2/test/circuit/standard.rs similarity index 100% rename from src/system/halo2/test/circuit/standard.rs rename to snark-verifier/src/system/halo2/test/circuit/standard.rs diff --git a/src/system/halo2/test/ipa.rs b/snark-verifier/src/system/halo2/test/ipa.rs similarity index 100% rename from src/system/halo2/test/ipa.rs rename to snark-verifier/src/system/halo2/test/ipa.rs diff --git a/src/system/halo2/test/ipa/native.rs b/snark-verifier/src/system/halo2/test/ipa/native.rs similarity index 100% rename from src/system/halo2/test/ipa/native.rs rename to snark-verifier/src/system/halo2/test/ipa/native.rs diff --git a/src/system/halo2/test/kzg.rs b/snark-verifier/src/system/halo2/test/kzg.rs similarity index 100% rename from src/system/halo2/test/kzg.rs rename to snark-verifier/src/system/halo2/test/kzg.rs diff --git a/src/system/halo2/test/kzg/evm.rs b/snark-verifier/src/system/halo2/test/kzg/evm.rs similarity index 100% rename from src/system/halo2/test/kzg/evm.rs rename to snark-verifier/src/system/halo2/test/kzg/evm.rs diff --git a/src/system/halo2/test/kzg/halo2.rs b/snark-verifier/src/system/halo2/test/kzg/halo2.rs similarity index 99% rename from src/system/halo2/test/kzg/halo2.rs rename to snark-verifier/src/system/halo2/test/kzg/halo2.rs index f120c82c..bd52426d 100644 --- a/src/system/halo2/test/kzg/halo2.rs +++ b/snark-verifier/src/system/halo2/test/kzg/halo2.rs @@ -368,7 +368,7 @@ impl Circuit for Accumulation { region, ContextParams { max_rows: config.base_field_config.range.gate.max_rows, - num_advice: vec![config.base_field_config.range.gate.num_advice], + num_context_ids: 1, fixed_columns: config.base_field_config.range.gate.constants.clone(), }, ); diff --git a/src/system/halo2/test/kzg/native.rs b/snark-verifier/src/system/halo2/test/kzg/native.rs similarity index 100% rename from src/system/halo2/test/kzg/native.rs rename to snark-verifier/src/system/halo2/test/kzg/native.rs diff --git a/src/system/halo2/transcript.rs b/snark-verifier/src/system/halo2/transcript.rs similarity index 100% rename from src/system/halo2/transcript.rs rename to snark-verifier/src/system/halo2/transcript.rs diff --git a/src/system/halo2/transcript/evm.rs b/snark-verifier/src/system/halo2/transcript/evm.rs similarity index 100% rename from src/system/halo2/transcript/evm.rs rename to snark-verifier/src/system/halo2/transcript/evm.rs diff --git a/src/system/halo2/transcript/halo2.rs b/snark-verifier/src/system/halo2/transcript/halo2.rs similarity index 100% rename from src/system/halo2/transcript/halo2.rs rename to snark-verifier/src/system/halo2/transcript/halo2.rs diff --git a/src/util.rs b/snark-verifier/src/util.rs similarity index 100% rename from src/util.rs rename to snark-verifier/src/util.rs diff --git a/src/util/arithmetic.rs b/snark-verifier/src/util/arithmetic.rs similarity index 100% rename from src/util/arithmetic.rs rename to snark-verifier/src/util/arithmetic.rs diff --git a/src/util/hash.rs b/snark-verifier/src/util/hash.rs similarity index 100% rename from src/util/hash.rs rename to snark-verifier/src/util/hash.rs diff --git a/src/util/hash/poseidon.rs b/snark-verifier/src/util/hash/poseidon.rs similarity index 100% rename from src/util/hash/poseidon.rs rename to snark-verifier/src/util/hash/poseidon.rs diff --git a/src/util/msm.rs b/snark-verifier/src/util/msm.rs similarity index 100% rename from src/util/msm.rs rename to snark-verifier/src/util/msm.rs diff --git a/src/util/poly.rs b/snark-verifier/src/util/poly.rs similarity index 100% rename from src/util/poly.rs rename to snark-verifier/src/util/poly.rs diff --git a/src/util/protocol.rs b/snark-verifier/src/util/protocol.rs similarity index 100% rename from src/util/protocol.rs rename to snark-verifier/src/util/protocol.rs diff --git a/src/util/transcript.rs b/snark-verifier/src/util/transcript.rs similarity index 100% rename from src/util/transcript.rs rename to snark-verifier/src/util/transcript.rs diff --git a/src/verifier.rs b/snark-verifier/src/verifier.rs similarity index 100% rename from src/verifier.rs rename to snark-verifier/src/verifier.rs diff --git a/src/verifier/plonk.rs b/snark-verifier/src/verifier/plonk.rs similarity index 100% rename from src/verifier/plonk.rs rename to snark-verifier/src/verifier/plonk.rs diff --git a/src/system/halo2/test/README.md b/src/system/halo2/test/README.md deleted file mode 100644 index 88becc84..00000000 --- a/src/system/halo2/test/README.md +++ /dev/null @@ -1,19 +0,0 @@ -In `plonk-verifier` root directory: - -1. Create `params` folder. Do not reuse params generated from other versions of `halo2_proofs` for now. - -2. Create `configs/verify_circuit.config`. - -3. Create `src/system/halo2/test/data` directory. Then run - -For single evm circuit verification: - -``` -cargo test --release -- --nocapture system::halo2::test::kzg::halo2::zkevm::test_shplonk_bench_evm_circuit --exact -``` - -For evm circuit + state circuit aggregation: - -``` -cargo test --release -- --nocapture system::halo2::test::kzg::halo2::zkevm::test_shplonk_bench_evm_and_state --exact -```