From c89793e27ac6645bbc997897ed05c9c769bd56b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Garillot?= Date: Sat, 16 Dec 2023 13:43:14 -0500 Subject: [PATCH] supernova This backports the following Arecibo PRs: - https://github.com/lurk-lab/arecibo/pull/2 - https://github.com/lurk-lab/arecibo/pull/3 - https://github.com/lurk-lab/arecibo/pull/10 - https://github.com/lurk-lab/arecibo/pull/16 - https://github.com/lurk-lab/arecibo/pull/23 - https://github.com/lurk-lab/arecibo/pull/30 - https://github.com/lurk-lab/arecibo/pull/28 - https://github.com/lurk-lab/arecibo/pull/41 - https://github.com/lurk-lab/arecibo/pull/45 - https://github.com/lurk-lab/arecibo/pull/50 - https://github.com/lurk-lab/arecibo/pull/56 - https://github.com/lurk-lab/arecibo/pull/51 - https://github.com/lurk-lab/arecibo/pull/72 - https://github.com/lurk-lab/arecibo/pull/92 - https://github.com/lurk-lab/arecibo/pull/95 - https://github.com/lurk-lab/arecibo/pull/97 - https://github.com/lurk-lab/arecibo/pull/101 - https://github.com/lurk-lab/arecibo/pull/110 - https://github.com/lurk-lab/arecibo/pull/106 - https://github.com/lurk-lab/arecibo/pull/112 - https://github.com/lurk-lab/arecibo/pull/114 - https://github.com/lurk-lab/arecibo/pull/119 - https://github.com/lurk-lab/arecibo/pull/120 - https://github.com/lurk-lab/arecibo/pull/127 - https://github.com/lurk-lab/arecibo/pull/123 - https://github.com/lurk-lab/arecibo/pull/131 - https://github.com/lurk-lab/arecibo/pull/174 - https://github.com/lurk-lab/arecibo/pull/175 - https://github.com/lurk-lab/arecibo/pull/182 Co-authored-by: WYATT Co-authored-by: Hanting Zhang Co-authored-by: Ming Co-authored-by: porcuquine Co-authored-by: Samuel Burnham <45365069+samuelburnham@users.noreply.github.com> Co-authored-by: Matej Penciak <96667244+mpenciak@users.noreply.github.com> Co-authored-by: Adrian Hamelink --- Cargo.toml | 20 +- benches/compressed-snark-supernova.rs | 328 ++++++ benches/recursive-snark-supernova.rs | 284 ++++++ notes/supernova.md | 313 ++++++ src/bellpepper/mod.rs | 2 +- src/bellpepper/r1cs.rs | 21 +- src/circuit.rs | 4 +- src/errors.rs | 3 + src/gadgets/ecc.rs | 8 +- src/gadgets/r1cs.rs | 129 ++- src/gadgets/utils.rs | 10 +- src/lib.rs | 135 ++- src/nifs.rs | 6 +- src/r1cs/mod.rs | 56 +- src/spartan/batched.rs | 621 ++++++++++++ src/spartan/batched_ppsnark.rs | 1352 +++++++++++++++++++++++++ src/spartan/direct.rs | 3 +- src/spartan/macros.rs | 104 ++ src/spartan/mod.rs | 216 ++-- src/spartan/polys/eq.rs | 26 +- src/spartan/polys/masked_eq.rs | 150 +++ src/spartan/polys/mod.rs | 1 + src/spartan/polys/multilinear.rs | 54 +- src/spartan/polys/power.rs | 28 +- src/spartan/polys/univariate.rs | 7 +- src/spartan/ppsnark.rs | 476 +++++---- src/spartan/snark.rs | 231 ++--- src/spartan/sumcheck.rs | 294 +++++- src/supernova/circuit.rs | 749 ++++++++++++++ src/supernova/error.rs | 19 + src/supernova/mod.rs | 1082 ++++++++++++++++++++ src/supernova/snark.rs | 749 ++++++++++++++ src/supernova/test.rs | 904 +++++++++++++++++ src/supernova/utils.rs | 179 ++++ src/traits/circuit_supernova.rs | 115 +++ src/traits/mod.rs | 1 + src/traits/snark.rs | 41 + 37 files changed, 8147 insertions(+), 574 deletions(-) create mode 100644 benches/compressed-snark-supernova.rs create mode 100644 benches/recursive-snark-supernova.rs create mode 100644 notes/supernova.md create mode 100644 src/spartan/batched.rs create mode 100644 src/spartan/batched_ppsnark.rs create mode 100644 src/spartan/macros.rs create mode 100644 src/spartan/polys/masked_eq.rs create mode 100644 src/supernova/circuit.rs create mode 100644 src/supernova/error.rs create mode 100644 src/supernova/mod.rs create mode 100644 src/supernova/snark.rs create mode 100644 src/supernova/test.rs create mode 100644 src/supernova/utils.rs create mode 100644 src/traits/circuit_supernova.rs diff --git a/Cargo.toml b/Cargo.toml index 7a7ca9c51..6b6cb6f5c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,16 @@ byteorder = "1.4.3" thiserror = "1.0" halo2curves = { version = "0.4.0", features = ["derive_serde"] } group = "0.13.0" +pairing = "0.23.0" +abomonation = "0.7.3" +abomonation_derive = { version = "0.1.0", package = "abomonation_derive_ng" } +tap = "1.0.1" +cfg-if = "1.0.0" once_cell = "1.18.0" +itertools = "0.12.0" +rand = "0.8.5" +ref-cast = "1.0.20" +log = "0.4.20" [target.'cfg(any(target_arch = "x86_64", target_arch = "aarch64"))'.dependencies] pasta-msm = { version = "0.1.4" } @@ -48,10 +57,9 @@ criterion = { version = "0.4", features = ["html_reports"] } flate2 = "1.0" hex = "0.4.3" pprof = { version = "0.11" } -cfg-if = "1.0.0" sha2 = "0.10.7" +tracing-test = "0.2.4" proptest = "1.2.0" -rand = "0.8.5" [[bench]] name = "recursive-snark" @@ -69,6 +77,14 @@ harness = false name = "sha256" harness = false +[[bench]] +name = "recursive-snark-supernova" +harness = false + +[[bench]] +name = "compressed-snark-supernova" +harness = false + [features] default = [] asm = ["halo2curves/asm"] diff --git a/benches/compressed-snark-supernova.rs b/benches/compressed-snark-supernova.rs new file mode 100644 index 000000000..72dde46af --- /dev/null +++ b/benches/compressed-snark-supernova.rs @@ -0,0 +1,328 @@ +#![allow(non_snake_case)] +use nova_snark::{ + supernova::NonUniformCircuit, + supernova::{snark::CompressedSNARK, PublicParams, RecursiveSNARK}, + traits::{ + circuit_supernova::{StepCircuit, TrivialTestCircuit}, + snark::BatchedRelaxedR1CSSNARKTrait, + snark::RelaxedR1CSSNARKTrait, + Engine, + }, +}; +use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; +use core::marker::PhantomData; +use criterion::{measurement::WallTime, *}; +use ff::PrimeField; +use std::time::Duration; + +type E1 = nova_snark::provider::PallasEngine; +type E2 = nova_snark::provider::VestaEngine; +type EE1 = nova_snark::provider::ipa_pc::EvaluationEngine; +type EE2 = nova_snark::provider::ipa_pc::EvaluationEngine; +// SNARKs without computational commitments +type S1 = nova_snark::spartan::batched::BatchedRelaxedR1CSSNARK; +type S2 = nova_snark::spartan::snark::RelaxedR1CSSNARK; +// SNARKs with computational commitments +type SS1 = nova_snark::spartan::batched_ppsnark::BatchedRelaxedR1CSSNARK; +type SS2 = nova_snark::spartan::ppsnark::RelaxedR1CSSNARK; + +// To run these benchmarks, first download `criterion` with `cargo install cargo-criterion`. +// Then `cargo criterion --bench compressed-snark-supernova`. The results are located in `target/criterion/data/`. +// For flamegraphs, run `cargo criterion --bench compressed-snark-supernova --features flamegraph -- --profile-time `. +// The results are located in `target/criterion/profile/`. +cfg_if::cfg_if! { + if #[cfg(feature = "flamegraph")] { + criterion_group! { + name = compressed_snark_supernova; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)).with_profiler(pprof::criterion::PProfProfiler::new(100, pprof::criterion::Output::Flamegraph(None))); + targets = bench_one_augmented_circuit_compressed_snark, bench_two_augmented_circuit_compressed_snark, bench_two_augmented_circuit_compressed_snark_with_computational_commitments + } + } else { + criterion_group! { + name = compressed_snark_supernova; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_one_augmented_circuit_compressed_snark, bench_two_augmented_circuit_compressed_snark, bench_two_augmented_circuit_compressed_snark_with_computational_commitments + } + } +} + +criterion_main!(compressed_snark_supernova); + +// This should match the value in test_supernova_recursive_circuit_pasta +// TODO: This should also be a table matching the num_augmented_circuits in the below +const NUM_CONS_VERIFIER_CIRCUIT_PRIMARY: usize = 9844; +const NUM_SAMPLES: usize = 10; + +struct NonUniformBench +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + S: StepCircuit + Default, +{ + num_circuits: usize, + num_cons: usize, + _p: PhantomData<(E1, E2, S)>, +} + +impl NonUniformBench +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + S: StepCircuit + Default, +{ + fn new(num_circuits: usize, num_cons: usize) -> Self { + Self { + num_circuits, + num_cons, + _p: Default::default(), + } + } +} + +impl + NonUniformCircuit, TrivialTestCircuit> + for NonUniformBench +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + S: StepCircuit + Default, +{ + fn num_circuits(&self) -> usize { + self.num_circuits + } + + fn primary_circuit(&self, circuit_index: usize) -> NonTrivialTestCircuit { + assert!( + circuit_index < self.num_circuits, + "Circuit index out of bounds: asked for {circuit_index}, but there are only {} circuits.", + self.num_circuits + ); + + NonTrivialTestCircuit::new(self.num_cons) + } + + fn secondary_circuit(&self) -> TrivialTestCircuit { + Default::default() + } +} + +/// Benchmarks the compressed SNARK at a provided number of constraints +/// +/// Parameters +/// - `num_augmented_circuits`: the number of augmented circuits in this configuration +/// - `group`: the criterion benchmark group +/// - `num_cons`: the number of constraints in the step circuit +fn bench_compressed_snark_internal_with_arity< + S1: BatchedRelaxedR1CSSNARKTrait, + S2: RelaxedR1CSSNARKTrait, +>( + group: &mut BenchmarkGroup<'_, WallTime>, + num_augmented_circuits: usize, + num_cons: usize, +) { + let bench: NonUniformBench::Scalar>> = + NonUniformBench::new(num_augmented_circuits, num_cons); + let pp = PublicParams::setup(&bench, &*S1::ck_floor(), &*S2::ck_floor()); + + let num_steps = 3; + let z0_primary = vec![::Scalar::from(2u64)]; + let z0_secondary = vec![::Scalar::from(2u64)]; + let mut recursive_snark_option: Option> = None; + let mut selected_augmented_circuit = 0; + + for _ in 0..num_steps { + let mut recursive_snark = recursive_snark_option.unwrap_or_else(|| { + RecursiveSNARK::new( + &pp, + &bench, + &bench.primary_circuit(0), + &bench.secondary_circuit(), + &z0_primary, + &z0_secondary, + ) + .unwrap() + }); + + if selected_augmented_circuit == 0 || selected_augmented_circuit == 1 { + let res = recursive_snark.prove_step( + &pp, + &bench.primary_circuit(selected_augmented_circuit), + &bench.secondary_circuit(), + ); + res.expect("Prove step failed"); + + let res = recursive_snark.verify(&pp, &z0_primary, &z0_secondary); + res.expect("Verify failed"); + } else { + unimplemented!() + } + + selected_augmented_circuit = (selected_augmented_circuit + 1) % num_augmented_circuits; + recursive_snark_option = Some(recursive_snark) + } + + assert!(recursive_snark_option.is_some()); + let recursive_snark = recursive_snark_option.unwrap(); + + let (prover_key, verifier_key) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&pp).unwrap(); + + // Benchmark the prove time + group.bench_function("Prove", |b| { + b.iter(|| { + assert!(CompressedSNARK::<_, _, _, _, S1, S2>::prove( + black_box(&pp), + black_box(&prover_key), + black_box(&recursive_snark) + ) + .is_ok()); + }) + }); + + let res = CompressedSNARK::<_, _, _, _, S1, S2>::prove(&pp, &prover_key, &recursive_snark); + + assert!(res.is_ok()); + let compressed_snark = res.unwrap(); + + // Benchmark the verification time + group.bench_function("Verify", |b| { + b.iter(|| { + assert!(black_box(&compressed_snark) + .verify( + black_box(&pp), + black_box(&verifier_key), + black_box(&z0_primary), + black_box(&z0_secondary), + ) + .is_ok()); + }) + }); +} + +fn bench_one_augmented_circuit_compressed_snark(c: &mut Criterion) { + // we vary the number of constraints in the step circuit + for &num_cons_in_augmented_circuit in [ + NUM_CONS_VERIFIER_CIRCUIT_PRIMARY, + 16384, + 32768, + 65536, + 131072, + 262144, + 524288, + 1048576, + ] + .iter() + { + // number of constraints in the step circuit + let num_cons = num_cons_in_augmented_circuit - NUM_CONS_VERIFIER_CIRCUIT_PRIMARY; + + let mut group = c.benchmark_group(format!( + "CompressedSNARKSuperNova-1circuit-StepCircuitSize-{num_cons}" + )); + group.sample_size(NUM_SAMPLES); + + bench_compressed_snark_internal_with_arity::(&mut group, 1, num_cons); + + group.finish(); + } +} + +fn bench_two_augmented_circuit_compressed_snark(c: &mut Criterion) { + // we vary the number of constraints in the step circuit + for &num_cons_in_augmented_circuit in [ + NUM_CONS_VERIFIER_CIRCUIT_PRIMARY, + 16384, + 32768, + 65536, + 131072, + 262144, + 524288, + 1048576, + ] + .iter() + { + // number of constraints in the step circuit + let num_cons = num_cons_in_augmented_circuit - NUM_CONS_VERIFIER_CIRCUIT_PRIMARY; + + let mut group = c.benchmark_group(format!( + "CompressedSNARKSuperNova-2circuit-StepCircuitSize-{num_cons}" + )); + group.sample_size(NUM_SAMPLES); + + bench_compressed_snark_internal_with_arity::(&mut group, 2, num_cons); + + group.finish(); + } +} + +fn bench_two_augmented_circuit_compressed_snark_with_computational_commitments(c: &mut Criterion) { + // we vary the number of constraints in the step circuit + for &num_cons_in_augmented_circuit in [ + NUM_CONS_VERIFIER_CIRCUIT_PRIMARY, + 16384, + 32768, + 65536, + 131072, + 262144, + 524288, + 1048576, + ] + .iter() + { + // number of constraints in the step circuit + let num_cons = num_cons_in_augmented_circuit - NUM_CONS_VERIFIER_CIRCUIT_PRIMARY; + + let mut group = c.benchmark_group(format!( + "CompressedSNARKSuperNova-Commitments-2circuit-StepCircuitSize-{num_cons}" + )); + group.sample_size(NUM_SAMPLES); + + bench_compressed_snark_internal_with_arity::(&mut group, 2, num_cons); + + group.finish(); + } +} +#[derive(Clone, Debug, Default)] +struct NonTrivialTestCircuit { + num_cons: usize, + _p: PhantomData, +} + +impl NonTrivialTestCircuit +where + F: PrimeField, +{ + pub fn new(num_cons: usize) -> Self { + Self { + num_cons, + _p: Default::default(), + } + } +} +impl StepCircuit for NonTrivialTestCircuit +where + F: PrimeField, +{ + fn arity(&self) -> usize { + 1 + } + + fn circuit_index(&self) -> usize { + 0 + } + + fn synthesize>( + &self, + cs: &mut CS, + pc: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result<(Option>, Vec>), SynthesisError> { + // Consider a an equation: `x^{2 * num_cons} = y`, where `x` and `y` are respectively the input and output. + let mut x = z[0].clone(); + let mut y = x.clone(); + for i in 0..self.num_cons { + y = x.square(cs.namespace(|| format!("x_sq_{i}")))?; + x = y.clone(); + } + Ok((pc.cloned(), vec![y])) + } +} diff --git a/benches/recursive-snark-supernova.rs b/benches/recursive-snark-supernova.rs new file mode 100644 index 000000000..8a43918d4 --- /dev/null +++ b/benches/recursive-snark-supernova.rs @@ -0,0 +1,284 @@ +#![allow(non_snake_case)] +use nova_snark::{ + provider::{PallasEngine, VestaEngine}, + supernova::NonUniformCircuit, + supernova::{PublicParams, RecursiveSNARK}, + traits::{ + circuit_supernova::{StepCircuit, TrivialTestCircuit}, + snark::default_ck_hint, + Engine, + }, +}; +use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; +use core::marker::PhantomData; +use criterion::{measurement::WallTime, *}; +use ff::PrimeField; +use std::time::Duration; + +// To run these benchmarks, first download `criterion` with `cargo install cargo-criterion`. +// Then `cargo criterion --bench recursive-snark-supernova`. The results are located in `target/criterion/data/`. +// For flamegraphs, run `cargo criterion --bench recursive-snark-supernova --features flamegraph -- --profile-time `. +// The results are located in `target/criterion/profile/`. +cfg_if::cfg_if! { + if #[cfg(feature = "flamegraph")] { + criterion_group! { + name = recursive_snark_supernova; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)).with_profiler(pprof::criterion::PProfProfiler::new(100, pprof::criterion::Output::Flamegraph(None))); + targets = bench_one_augmented_circuit_recursive_snark, bench_two_augmented_circuit_recursive_snark + } + } else { + criterion_group! { + name = recursive_snark_supernova; + config = Criterion::default().warm_up_time(Duration::from_millis(3000)); + targets = bench_one_augmented_circuit_recursive_snark, bench_two_augmented_circuit_recursive_snark + } + } +} + +criterion_main!(recursive_snark_supernova); + +// This should match the value in test_supernova_recursive_circuit_pasta +// TODO: This should also be a table matching the num_augmented_circuits in the below +const NUM_CONS_VERIFIER_CIRCUIT_PRIMARY: usize = 9844; +const NUM_SAMPLES: usize = 10; + +struct NonUniformBench +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + S: StepCircuit + Default, +{ + num_circuits: usize, + num_cons: usize, + _p: PhantomData<(E1, E2, S)>, +} + +impl NonUniformBench +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + S: StepCircuit + Default, +{ + fn new(num_circuits: usize, num_cons: usize) -> Self { + Self { + num_circuits, + num_cons, + _p: Default::default(), + } + } +} + +impl + NonUniformCircuit, TrivialTestCircuit> + for NonUniformBench +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + S: StepCircuit + Default, +{ + fn num_circuits(&self) -> usize { + self.num_circuits + } + + fn primary_circuit(&self, circuit_index: usize) -> NonTrivialTestCircuit { + assert!(circuit_index < self.num_circuits); + + NonTrivialTestCircuit::new(self.num_cons) + } + + fn secondary_circuit(&self) -> TrivialTestCircuit { + Default::default() + } +} + +/// Benchmarks the compressed SNARK at a provided number of constraints +/// +/// Parameters +/// - `num_augmented_circuits`: the number of augmented circuits in this configuration +/// - `group`: the criterion benchmark group +/// - `num_cons`: the number of constraints in the step circuit +fn bench_recursive_snark_internal_with_arity( + group: &mut BenchmarkGroup<'_, WallTime>, + num_augmented_circuits: usize, + num_cons: usize, +) { + let bench: NonUniformBench< + PallasEngine, + VestaEngine, + TrivialTestCircuit<::Scalar>, + > = NonUniformBench::new(2, num_cons); + let pp = PublicParams::setup(&bench, &*default_ck_hint(), &*default_ck_hint()); + + // Bench time to produce a recursive SNARK; + // we execute a certain number of warm-up steps since executing + // the first step is cheaper than other steps owing to the presence of + // a lot of zeros in the satisfying assignment + let num_warmup_steps = 10; + let z0_primary = vec![::Scalar::from(2u64)]; + let z0_secondary = vec![::Scalar::from(2u64)]; + let mut recursive_snark_option: Option> = None; + let mut selected_augmented_circuit = 0; + + for _ in 0..num_warmup_steps { + let mut recursive_snark = recursive_snark_option.unwrap_or_else(|| { + RecursiveSNARK::new( + &pp, + &bench, + &bench.primary_circuit(0), + &bench.secondary_circuit(), + &z0_primary, + &z0_secondary, + ) + .unwrap() + }); + + if selected_augmented_circuit == 0 || selected_augmented_circuit == 1 { + recursive_snark + .prove_step( + &pp, + &bench.primary_circuit(selected_augmented_circuit), + &bench.secondary_circuit(), + ) + .expect("Prove step failed"); + + recursive_snark + .verify(&pp, &z0_primary, &z0_secondary) + .expect("Verify failed"); + } else { + unimplemented!() + } + + selected_augmented_circuit = (selected_augmented_circuit + 1) % num_augmented_circuits; + recursive_snark_option = Some(recursive_snark) + } + + assert!(recursive_snark_option.is_some()); + let recursive_snark = recursive_snark_option.unwrap(); + + // Benchmark the prove time + group.bench_function("Prove", |b| { + b.iter(|| { + // produce a recursive SNARK for a step of the recursion + assert!(black_box(&mut recursive_snark.clone()) + .prove_step( + black_box(&pp), + &bench.primary_circuit(0), + &bench.secondary_circuit(), + ) + .is_ok()); + }) + }); + + // Benchmark the verification time + group.bench_function("Verify", |b| { + b.iter(|| { + assert!(black_box(&mut recursive_snark.clone()) + .verify( + black_box(&pp), + black_box(&[::Scalar::from(2u64)]), + black_box(&[::Scalar::from(2u64)]), + ) + .is_ok()); + }); + }); +} + +fn bench_one_augmented_circuit_recursive_snark(c: &mut Criterion) { + // we vary the number of constraints in the step circuit + for &num_cons_in_augmented_circuit in [ + NUM_CONS_VERIFIER_CIRCUIT_PRIMARY, + 16384, + 32768, + 65536, + 131072, + 262144, + 524288, + 1048576, + ] + .iter() + { + // number of constraints in the step circuit + let num_cons = num_cons_in_augmented_circuit - NUM_CONS_VERIFIER_CIRCUIT_PRIMARY; + + let mut group = c.benchmark_group(format!( + "RecursiveSNARKSuperNova-1circuit-StepCircuitSize-{num_cons}" + )); + group.sample_size(NUM_SAMPLES); + + bench_recursive_snark_internal_with_arity(&mut group, 1, num_cons); + group.finish(); + } +} + +fn bench_two_augmented_circuit_recursive_snark(c: &mut Criterion) { + // we vary the number of constraints in the step circuit + for &num_cons_in_augmented_circuit in [ + NUM_CONS_VERIFIER_CIRCUIT_PRIMARY, + 16384, + 32768, + 65536, + 131072, + 262144, + 524288, + 1048576, + ] + .iter() + { + // number of constraints in the step circuit + let num_cons = num_cons_in_augmented_circuit - NUM_CONS_VERIFIER_CIRCUIT_PRIMARY; + + let mut group = c.benchmark_group(format!( + "RecursiveSNARKSuperNova-2circuit-StepCircuitSize-{num_cons}" + )); + group.sample_size(NUM_SAMPLES); + + bench_recursive_snark_internal_with_arity(&mut group, 2, num_cons); + group.finish(); + } +} + +#[derive(Clone, Debug, Default)] +struct NonTrivialTestCircuit { + num_cons: usize, + _p: PhantomData, +} + +impl NonTrivialTestCircuit +where + F: PrimeField, +{ + pub fn new(num_cons: usize) -> Self { + Self { + num_cons, + _p: Default::default(), + } + } +} +impl StepCircuit for NonTrivialTestCircuit +where + F: PrimeField, +{ + fn arity(&self) -> usize { + 1 + } + + fn circuit_index(&self) -> usize { + 0 + } + + fn synthesize>( + &self, + cs: &mut CS, + pc: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result<(Option>, Vec>), SynthesisError> { + // Consider a an equation: `x^2 = y`, where `x` and `y` are respectively the input and output. + let mut x = z[0].clone(); + let mut y = x.clone(); + for i in 0..self.num_cons { + y = x.square(cs.namespace(|| format!("x_sq_{i}")))?; + x = y.clone(); + } + Ok((pc.cloned(), vec![y])) + } +} diff --git a/notes/supernova.md b/notes/supernova.md new file mode 100644 index 000000000..e8090b5af --- /dev/null +++ b/notes/supernova.md @@ -0,0 +1,313 @@ +# SuperNova Description + +This document explains from a high-level how the SuperNova protocol was implemented in Arecibo. +We aim to provide a mathematical description of the protocol, as it is implemented, and highlight the differences with the original paper. + +## Terminology and Concept Clarifications + +Before delving into the specifics of the implementation, it's crucial to define and clarify some key terms and concepts used throughout this document: + +- **Recursive SNARK**: A Recursive SNARK is a type of succinct non-interactive argument of knowledge for a circuit $F$ which can be composed with itself as $z\_{i+1} \gets F(z_i)$. +Each iteration proves the verification of a proof for $z_i$ and the correctness of $z\_{i+1}$, ensuring the proving of each step remains constant. +- **Augmentation Circuit**: In the context of the SuperNova protocol, an augmentation circuit refers to a circuit $F'$ composing $F$ with a circuit which partially verifies the validity of the previous output $z_i$ before running $F(z_i)$. +- **NIFS Folding Verifier**: A non-interactive folding scheme is a protocol for efficiently updating a proof $\pi_i$ about an iterated function $z\_{i+1} \gets F(z_i)$ into a new proof $\pi\_{i+1}$, through a process referred to as "folding". +By splitting the proof into an instance/witness pair $(u,w) = \pi$, the folding verifier describes an algorithm for verifying that the $u$ component was properly updated. + +## SuperNova vs. Nova + +The main improvement of SuperNova, is to allow each iteration to apply one of several functions to the previous output, whereas Nova only supported the iteration of a single function. + +Let $F_0, \ldots, F\_{\ell-1}$ be folding circuits with the same arity $a$. +In the context of SuperNova, this means that each $F_j$ takes $a$ inputs from the previous iteration, and returns $a$ outputs. +These circuits implement the `circuit_supernova::StepCircuit` trait, where the main differences with the existing `StepCircuit` trait are +- The circuit $F_j$ provides its `circuit_index` $j$ +- The `synthesize` function upon input $z_i$ returns the next `program_counter` $\mathsf{pc}\_{i+1}$ alongside the output $z\_{i+1}$. It also accepts the (optional) input program counter $\mathsf{pc}_i$, which can be `None` when $\ell=1$. During circuit synthesis, a constraint enforces $\mathsf{pc}_i \equiv j$. In contrast to the paper, the _predicate_ function $\varphi$ is built into the circuit itself. In other words, we have the signature $(\mathsf{pc}\_{i+1}, z\_{i+1}) \gets F\_{j}(\mathsf{pc}\_{i}, z\_{i})$. + +The goal is to efficiently prove the following computation: +```ignore +pc_i = pc_0 +z_i = z_0 +for i in 0..num_steps + (pc_i, z_i) = F_{pc_i}(z_i) +return z_i +``` + +## Cycles of Curves + +"Cycles of Curves" describes a technique for more efficiently verifying the output $z_i$ of the previous circuit iteration, by running the verification on a curve whose base/scalar fields are inverted. +The result is that the elliptic curve scalar multiplications in the algorithm can be computed in the "native field" of the circuit, minimizing the need for expensive non-native field arithmetic. + +While the original Nova implementation allows computation to be done on both curves, the SuperNova implementation only uses the cycle curve to verify the computation performed on the primary curve. + +## Prover state + +The prover needs to store data about the previous function iteration. It is defined by the `supernova::RecursiveSNARK` struct. It contains: + +- $i$: the number of iterations performed. +Note that the `new` constructor actually performs the first iteration, and the first call to `prove_step` simply sets the counter to 1. +- Primary curve: + - $(\mathsf{pc}_i, z_0, z_i)$: current program counter and inputs for the primary circuit + - $U[\ ],W[\ ]$: List of relaxed instance/witness pairs for all the circuits on the primary curve. + These can be `None` when the circuit for that pair has not yet been executed. + The last updated entry is the result of having folded a proof for the correctness of $z_i$. +- Secondary curve + - $(z_0', z_i')$: Inputs for the single circuit on the secondary curve. + - $u',w'$: Proof for the correctness of the circuit that produced $z_i'$ + - $U', W'$: Relaxed instance/witness pair into which $(u', w')$ will be folded into in the next iteration. + +Due to the particularities of the cycles of curves implementation, the outputs of the circuits producing $(z_i, z_i')$ are encoded in public outputs of the proof $u'$. + + +## Prove Step + +At each step, the prover needs to: +- Create a proof $T'$ for folding $u'$ into $U'$, producing $U'\_{next}$. +- Create a proof $(u,w)$ on the primary curve for the statements: + - $(\mathsf{pc}\_{i+1}, z\_{i+1}) \gets F_\mathsf{pc_i}(z_i)$ + - Verifying the folding of $u'$ into $U'$ with $T'$ +- Create a proof $T$ for folding $u$ into $U[\mathsf{pc}_i]$, producing $U\_{next}$ +- Create a proof $(u'\_{next}, w'\_{next})$ for the verification on the secondary curve + - Verifying the folding of $u$ into $U[\mathsf{pc}_i]$ with $T$ +- Update the state of the claims + - $U' = U'\_{next}$, $W' = W'\_{next}$ + - $U[\mathsf{pc}_i] = U\_{next}$, $W[\mathsf{pc}_i] = W\_{next}$ + - $(u',w') = (u'\_{next}, w'\_{next})$ +- Save $z\_{i+1},z'\_{i+1}, \mathsf{pc}\_{i+1}$ as inputs for the next iteration. + +In pseudocode, `prove_step` looks something like: + +```ignore +if i = 0 { + U[] = [ø;l] + + // Create a proof for the first iteration of F on the primary curve + (pc_1, z_1), (u_1, w_1) <- Prove(F_{pc0}, + i=0, + pc_0, + z_0, + _, // z_i : z_0 is the input used + _, // U' : Existing accumulator is empty + _, // u' : No proof of the secondary curve to verify + _, // T' : Nothing to fold + 0, // index of u' in U' + ) + // The circuit output is [ vk, i=1, pc_1, z_0, z_1, U'=ø ] + // Update state to catch up with verifier + z_i = z_1 + pc_i = pc_1 + U' = ø + W' = ø + + // Create proof on secondary curve + // verifying the validity of the first proof + z'_1, (u'_1, w'_1) <- Prove(F', + i, + 0, // pc is always 0 on secondary curve + z'_0, + _, // z'_i : z'_0 is the input used + _, // U[]: all accumulators on primary curve are empty + u_0, // proof for z1 + _, // T: u_0 is directly included into U[pc0] + pc_1, // index of u_0 in U[] + ) + // The circuit outputs [ vk, i=1, z'_0, z'_1, U_next[] ] + // Update state to catch up with verifier + z_i' = z_1' + U[pc_1] = u_1 + W[pc_1] = w_1 + + // Save the proof of F' to be folded into U' in the next iteration + u' = u'_1 + w' = w'_1 +} else { + // Create folding proof for u' into U', producing U'_next + (U'_next, W'_next), T' <- NIFS.Prove(U', W', u', w') + + // Create a proof for the next iteration of F on the primary curve + (pc_next, z_next), (u_new, w_new) <- Prove(F_{pc_i}, + i, + pc_i, + z_0, + z_i, + [U'], + u', + T', + 0, // index of u' in [U'] is always 0 + ) + // The circuit outputs [ vk, i+1, pc_next, z_0, z_next, U'_next ] + // Update state to catch up with verifier + z_i = z_next + pc_i = pc_next + U' = U'_next + W' = W'_next + + // Create folding proof for u_new into U[pci], producing U_next + (U_next, W_next), T <- NIFS.Prove(U[pci], W[pci], u_new, w_new) + + // Create proof on secondary curve + // verifying the folding of u_next into + z'_next, (u'_next, w'_next) <- Prove(F', + i, + 0, // pc is always 0 on secondary curve + z_0', + z_i', + U[], + u_new, + T, + pc_i, // Index of u_new in U[] + ) + // The circuit outputs [ vk, i+1, z'_0, z'_next, U_next[] ] + // Update state to catch up with verifier + z_i' = z'_next + U[pc_next] = U_next + W[pc_next] = W_next + + // Save the proof of F' to be folded into U' in the next iteration + u' = u'_next + w' = w'_next +} +``` + +Each iteration stops when the prover has produced a valid R1CS instance $(u',w')$ for the secondary circuit, just before folding it back into its accumulator $(U',W')$ in the next iteration. +This allows us to access the public outputs of the secondary circuit in the next iteration, or when verifying the IVC chain. + +## Augmented Circuit + +During each proof iteration, the circuits evaluated and proved by the prover need to be *augmented* to include additional constraints which verify that the previous iteration was correctly accumulated. + +To minimize code duplication, there is only a single version of the recursive verification circuit. The circuit is customized depending on whether it is synthesized on the primary/secondary curve. + +### Input Allocation + +The inputs of provided to the augmented step circuit $F'_j$ are: + +**Inputs for step circuit** +- $\mathsf{vk}$: a digest of the verification key for the final compressing SNARK (which includes all public parameters of all circuits) +- $i \in \mathbb{Z}\_{\geq 0}$: the number of iteration of the functions before running $F$ +- $\mathsf{pc}_i \in [\ell]$: index of the current function being executed + - **Primary**: The program counter $\mathsf{pc}_i$ must always be `Some`, and through the `EnforcingStepCircuit` trait, we enforce $\mathsf{pc}_i \equiv j$. + - **Secondary**: Always `None`, and interpreted as $\mathsf{pc}_i \equiv 0$, since there is only a single circuit. +- $z_0 \in \mathbb{F}^a$: inputs for the first iteration of $F$ +- $z_i \in \mathbb{F}^a$: inputs for the current iteration of $F$ + - **Base case**: Set to `None`, in which case it is allocated as $[0]$, and $z_0$ is used as $z_i$. +- $U_i[\ ] \in \mathbb{R}'^\ell$: list of relaxed R1CS instances on the other curve + - **Primary**: Since there is only a single circuit on the secondary curve, we have $\ell = 0$ and therefore $U_i[\ ]$ only contains a single `RelaxedR1CSInstance`. + - **Secondary**: The list of input relaxed instances $U_i[\ ]$ is initialized by passing a slice `[Option>]`, one for each circuit on the primary curve. + Since some of these instances do not exist yet (i.e. for circuits which have not been executed yet), the `None` entries are allocated as a default instance. + +To minimize the cost related to handling public inputs/outputs of the circuit, these values are hashed as $H(\mathsf{vk}, i, \mathsf{pc}_i, z_0, z_i, U_i[\ ])$. +In the first iteration though, the hash comparison is skipped, and the optional values are conditionally replaced with constrained default values. + +**Auxiliary inputs for recursive verification of other the curve's circuit** +- $u \in \mathbb{R}'$: fresh R1CS instance for the previous iteration on the other curve + - Contains the public outputs of the 2 previous circuits on the different curves. + - **Base case -- Primary**: Set to `None`, since there is no proof of the secondary curve to fold +- $T \in \mathbb{G}'$: Proof for folding $u$ into $U[\mathsf{pc}']$. + - **Base case -- Primary**: Set to `None`, since there is no proof of the secondary curve to fold +- $\mathsf{pc}' \in [\ell]$: index of the previously executed function on the other curve. + - **Primary**: Always 0 since the program counter on the secondary curve is always 0 + - **Secondary**: Equal to the program counter of the last function proved on the primary curve. + +These non-deterministic inputs are used to compute the circuit's outputs. +When they are empty, we allocate checked default values instead. +We also check that the computed hash of the inputs matches the hash of the output of the previous iteration contained in $u$. + +**Outputs** +- $\mathsf{vk}$: passed along as-is +- $i+1 \in \mathbb{Z}\_{\geq 0}$: the incremented number of iterations +- $\mathsf{pc}\_{i+1} \in [\ell]$: index of next function to execute +- $z_0 \in \mathbb{F}^a$: passed along as-is +- $z\_{i+1} \in \mathbb{F}^a$: output of the execution $F\_{\mathsf{pc}_i}$ +- $U\_{i+1}[\ ] \in \mathbb{R}'^\ell$: Updated list of Relaxed R1CS instances, reflecting the folding of $u$ into $U_i[\mathsf{pc}']$ + - **Primary**: Since no input proof was provided, we set $U_1$ to the default initial instance. + +All these values should be computed deterministically from the inputs described above (even if just passed along as-is). +The actual public output is the hash of these values, to be consistent with the encoding of the inputs. + +### Constraints + +The circuit has a branching depending on whether it is verifying the first iteration of the IVC chain. Each branch computes the next list of instances $U\_{i+1}[\ ]$. + +#### Branch: i>0 `synthesize_non_base_case` + +The verification circuit first checks that the public output $u.X_0$ is equal to the hash of all outputs of the previous circuit iteration. +Note that this value is defined as a public output of the proof $u$ on the other curve. +It was simply passed along unverified by the cycle circuit to link the two circuits from the same curve. +Since the base case does not have any previous input, we only check the hash if $i>0$. +The circuit produces a bit corresponding to: + +$$b\_{X_0} \gets X_0 \stackrel{?}{=} H(\mathsf{vk}, i, \mathsf{pc}_i, z_0, z_i, U_i[\ ])$$ + +This bit is checked later on. + +The circuit extracts $U_i[\mathsf{pc}']$ by using conditional selection on $U_i[\ ]$. +This is done by computing a selector vector $s \in \{0,1\}^\ell$ such that $s\_{\mathsf{pc}'}=1$ and all other entries are 0. + +The instance new folding instance $U\_{i+1}[\mathsf{pc}']$ is produced by running the NIFS folding verifier: + +$$ +U\_{i+1}[\mathsf{pc}'] \gets \mathsf{NIFS}.\mathsf{Verify}(\mathsf{vk}, u, U[\mathsf{pc}'], T) +$$ + +A new list of accumulators $U\_{i+1}[\ ]$ is then obtained using conditional selection. +This branch returns $U\_{i+1}[\ ]$, $b\_{X_0}$ as well as the selector $s$. + +#### Branch: i=0 (`synthesize_base_case`) + +If $i \equiv 0$, then the verification circuit must instantiate the inputs as their defaults. +Namely, it initializes a list $U_0[\ ]$ (different from the input list which is given to the previous branch) with "empty instances" (all group elements are set to the identity). + +The ouptut list of instances $U_1[\ ]$ is +- **Primary curve**: the incomming proof $u$ is trivial, so the result of folding two trivial instances is defined as the trivial relaxed instance. +- **Secondary curve**: the instance $U_0[\mathsf{pc}']$ is simply replaced with the relaxation of $u$ using conditional selection. + +This branch returns $U_1[\ ]$. + +#### Remaining constraints + +Having run both branches, the circuit has computed +- $U\_{i+1}[\ ], b\_{X_0}, s$ from the first branch +- $U_1[\ ]$ from the second branch + +- Using the bit $b\_{i=0} \gets i \stackrel{?}{=} 0$, it needs to conditionally select which list of instance to return. + - $U\_{i+1} \gets b\_{i=0} \ \ ?\ \ U\_{1}[\ ] \ \ :\ \ U\_{i+1}[\ ]$ +- Check that $(i\neq 0) \implies b\_{X_0}$, enforcing that the hash is correct when not handling the base case + - $b\_{i=0} \lor b\_{X_0}$ +- Select + - $z_i \gets b\_{i=0} \ \ ?\ \ z_0 \ \ :\ \ z_i$ +- Enforce circuit selection + - $\mathsf{pc}\_{i} \equiv j$ +- Compute next output + - $(\mathsf{pc}\_{i+1}, z\_{i+1}) \gets F_j(z_i)$ + + +### Public Outputs + +The output at this point would be + +$$ +\Big (i+1, \mathsf{pc}\_{i+1}, z_0, z\_{i+1}, U\_{i+1}\Big) +$$ + +To keep the number of public outputs small, the outputs of the circuit are hashed into a single field element. We create this hash as $H\_{out} = H\big (\mathsf{vk}, i+1, \mathsf{pc}\_{i+1}, z_0, z\_{i+1}, U\_{next}\big)$. + +We also return the hash resulting from the output on the other curve, $u.X_1$. It will be unpacked at the start of the next iteration of the circuit on the cycle curve, so we swap it and place it first. The actual public output is then. + +$$ +[u.X_1, H\_{out}] +$$ + +We can view output as the shared state between the circuits on the two curve. The list of two elements is a queue, where the last inserted element is popped out to be consumed by the verification circuit, and the resulting output is added to the end of the queue. + +## Verification + +After any number of iterations of `prove_step`, we can check that the current prover state is correct. In particular, we want to ensure that $(z_i, z'_i)$ are the correct outputs after having run $i$ iterations of the folding prover. + +To verify that $(z_i, z'_i)$ are correct, the verifier needs to recompute the public outputs of the latest proof $u'$. Since this is the output on the secondary curve, the first entry $u'.X_0$ will be the output of the primary curve circuit producing $(\mathsf{pc}_i, z_i)$ and the accumulator $U'$ in which we will fold $u'$. The second entry $u'.X_1$ is the output of the last circuit on the secondary curve, which will have folded the proof for $(\mathsf{pc}_i, z_i)$ into $U[\ ]$. + +- $u'.X_0 \stackrel{?}{=} H(\mathsf{vk}, i, \mathsf{pc}_i, z_0, z_i, U')$ +- $u'.X_1 \stackrel{?}{=} H'(\mathsf{vk}, i, z'_0, z'_i, U[\ ])$ + +We then verify that $(u',w')$ is a satisfying circuit, which proves that all relaxed instances $U[\ ], U'$ were correctly updated through by folding proof. + +We then need to verify that all accumulators $(U[\ ], W[\ ])$ and $(U', W')$ are correct by checking the circuit satisfiability. \ No newline at end of file diff --git a/src/bellpepper/mod.rs b/src/bellpepper/mod.rs index 1309fe23d..6660c4dcd 100644 --- a/src/bellpepper/mod.rs +++ b/src/bellpepper/mod.rs @@ -45,7 +45,7 @@ mod tests { // First create the shape let mut cs: ShapeCS = ShapeCS::new(); synthesize_alloc_bit(&mut cs); - let (shape, ck) = cs.r1cs_shape(&*default_ck_hint()); + let (shape, ck) = cs.r1cs_shape_and_key(&*default_ck_hint()); // Now get the assignment let mut cs = SatisfyingAssignment::::new(); diff --git a/src/bellpepper/r1cs.rs b/src/bellpepper/r1cs.rs index 278390bd9..fe6df07c0 100644 --- a/src/bellpepper/r1cs.rs +++ b/src/bellpepper/r1cs.rs @@ -5,7 +5,7 @@ use super::{shape_cs::ShapeCS, solver::SatisfyingAssignment, test_shape_cs::TestShapeCS}; use crate::{ errors::NovaError, - r1cs::{CommitmentKeyHint, R1CSInstance, R1CSShape, R1CSWitness, SparseMatrix, R1CS}, + r1cs::{commitment_key, CommitmentKeyHint, R1CSInstance, R1CSShape, R1CSWitness, SparseMatrix}, traits::Engine, CommitmentKey, }; @@ -27,7 +27,14 @@ pub trait NovaShape { /// Return an appropriate `R1CSShape` and `CommitmentKey` structs. /// A `CommitmentKeyHint` should be provided to help guide the construction of the `CommitmentKey`. /// This parameter is documented in `r1cs::R1CS::commitment_key`. - fn r1cs_shape(&self, ck_hint: &CommitmentKeyHint) -> (R1CSShape, CommitmentKey); + fn r1cs_shape_and_key(&self, ck_hint: &CommitmentKeyHint) -> (R1CSShape, CommitmentKey) { + let S = self.r1cs_shape(); + let ck = commitment_key(&S, ck_hint); + + (S, ck) + } + /// Return an appropriate `R1CSShape`. + fn r1cs_shape(&self) -> R1CSShape; } impl NovaWitness for SatisfyingAssignment { @@ -53,10 +60,10 @@ macro_rules! impl_nova_shape { where E::Scalar: PrimeField, { - fn r1cs_shape(&self, ck_hint: &CommitmentKeyHint) -> (R1CSShape, CommitmentKey) { + fn r1cs_shape(&self) -> R1CSShape { let mut A = SparseMatrix::::empty(); let mut B = SparseMatrix::::empty(); - let mut C = SparseMatrix::::empty(); + let mut C: SparseMatrix<::Scalar> = SparseMatrix::::empty(); let mut num_cons_added = 0; let mut X = (&mut A, &mut B, &mut C, &mut num_cons_added); @@ -80,10 +87,8 @@ macro_rules! impl_nova_shape { C.cols = num_vars + num_inputs; // Don't count One as an input for shape's purposes. - let S = R1CSShape::new(num_constraints, num_vars, num_inputs - 1, A, B, C).unwrap(); - let ck = R1CS::::commitment_key(&S, ck_hint); - - (S, ck) + let res = R1CSShape::new(num_constraints, num_vars, num_inputs - 1, A, B, C); + res.unwrap() } } }; diff --git a/src/circuit.rs b/src/circuit.rs index 89ec97684..28f0c9177 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -396,7 +396,7 @@ mod tests { NovaAugmentedCircuit::new(primary_params, None, &tc1, ro_consts1.clone()); let mut cs: TestShapeCS = TestShapeCS::new(); let _ = circuit1.synthesize(&mut cs); - let (shape1, ck1) = cs.r1cs_shape(&*default_ck_hint()); + let (shape1, ck1) = cs.r1cs_shape_and_key(&*default_ck_hint()); assert_eq!(cs.num_constraints(), num_constraints_primary); let tc2 = TrivialCircuit::default(); @@ -405,7 +405,7 @@ mod tests { NovaAugmentedCircuit::new(secondary_params, None, &tc2, ro_consts2.clone()); let mut cs: TestShapeCS = TestShapeCS::new(); let _ = circuit2.synthesize(&mut cs); - let (shape2, ck2) = cs.r1cs_shape(&*default_ck_hint()); + let (shape2, ck2) = cs.r1cs_shape_and_key(&*default_ck_hint()); assert_eq!(cs.num_constraints(), num_constraints_secondary); // Execute the base case for the primary diff --git a/src/errors.rs b/src/errors.rs index fd6922082..121704b9e 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -21,6 +21,9 @@ pub enum NovaError { /// returned if the supplied witness is not a satisfying witness to a given shape and instance #[error("UnSat")] UnSat, + /// returned if the supplied witness is not a satisfying witness to a given shape and instance, with error constraint index + #[error("UnSatIndex")] + UnSatIndex(usize), /// returned when the supplied compressed commitment cannot be decompressed #[error("DecompressionError")] DecompressionError, diff --git a/src/gadgets/ecc.rs b/src/gadgets/ecc.rs index 1540ccdf9..fabc19039 100644 --- a/src/gadgets/ecc.rs +++ b/src/gadgets/ecc.rs @@ -1031,8 +1031,7 @@ mod tests { // First create the shape let mut cs: TestShapeCS = TestShapeCS::new(); let _ = synthesize_smul::(cs.namespace(|| "synthesize")); - println!("Number of constraints: {}", cs.num_constraints()); - let (shape, ck) = cs.r1cs_shape(&*default_ck_hint()); + let (shape, ck) = cs.r1cs_shape_and_key(&*default_ck_hint()); // Then the satisfying assignment let mut cs = SatisfyingAssignment::::new(); @@ -1088,7 +1087,7 @@ mod tests { let mut cs: TestShapeCS = TestShapeCS::new(); let _ = synthesize_add_equal::(cs.namespace(|| "synthesize add equal")); println!("Number of constraints: {}", cs.num_constraints()); - let (shape, ck) = cs.r1cs_shape(&*default_ck_hint()); + let (shape, ck) = cs.r1cs_shape_and_key(&*default_ck_hint()); // Then the satisfying assignment let mut cs = SatisfyingAssignment::::new(); @@ -1147,8 +1146,7 @@ mod tests { // First create the shape let mut cs: TestShapeCS = TestShapeCS::new(); let _ = synthesize_add_negation::(cs.namespace(|| "synthesize add equal")); - println!("Number of constraints: {}", cs.num_constraints()); - let (shape, ck) = cs.r1cs_shape(&*default_ck_hint()); + let (shape, ck) = cs.r1cs_shape_and_key(&*default_ck_hint()); // Then the satisfying assignment let mut cs = SatisfyingAssignment::::new(); diff --git a/src/gadgets/r1cs.rs b/src/gadgets/r1cs.rs index 8bc8a3b89..529de7408 100644 --- a/src/gadgets/r1cs.rs +++ b/src/gadgets/r1cs.rs @@ -56,6 +56,7 @@ impl AllocatedR1CSInstance { } /// An Allocated Relaxed R1CS Instance +#[derive(Clone)] pub struct AllocatedRelaxedR1CSInstance { pub(crate) W: AllocatedPoint, pub(crate) E: AllocatedPoint, @@ -323,45 +324,109 @@ impl AllocatedRelaxedR1CSInstance { /// If the condition is true then returns this otherwise it returns the other pub fn conditionally_select::Base>>( &self, - mut cs: CS, + cs: CS, other: &AllocatedRelaxedR1CSInstance, condition: &Boolean, ) -> Result, SynthesisError> { - let W = AllocatedPoint::conditionally_select( - cs.namespace(|| "W = cond ? self.W : other.W"), - &self.W, - &other.W, - condition, - )?; + conditionally_select_alloc_relaxed_r1cs(cs, self, other, condition) + } +} - let E = AllocatedPoint::conditionally_select( - cs.namespace(|| "E = cond ? self.E : other.E"), - &self.E, - &other.E, +/// c = cond ? a: b, where a, b: `AllocatedRelaxedR1CSInstance` +pub fn conditionally_select_alloc_relaxed_r1cs< + E: Engine, + CS: ConstraintSystem<::Base>, +>( + mut cs: CS, + a: &AllocatedRelaxedR1CSInstance, + b: &AllocatedRelaxedR1CSInstance, + condition: &Boolean, +) -> Result, SynthesisError> { + let c = AllocatedRelaxedR1CSInstance { + W: conditionally_select_point( + cs.namespace(|| "W = cond ? a.W : b.W"), + &a.W, + &b.W, condition, - )?; - - let u = conditionally_select( - cs.namespace(|| "u = cond ? self.u : other.u"), - &self.u, - &other.u, + )?, + E: conditionally_select_point( + cs.namespace(|| "E = cond ? a.E : b.E"), + &a.E, + &b.E, condition, - )?; - - let X0 = conditionally_select_bignat( - cs.namespace(|| "X[0] = cond ? self.X[0] : other.X[0]"), - &self.X0, - &other.X0, + )?, + u: conditionally_select( + cs.namespace(|| "u = cond ? a.u : b.u"), + &a.u, + &b.u, condition, - )?; - - let X1 = conditionally_select_bignat( - cs.namespace(|| "X[1] = cond ? self.X[1] : other.X[1]"), - &self.X1, - &other.X1, + )?, + X0: conditionally_select_bignat( + cs.namespace(|| "X[0] = cond ? a.X[0] : b.X[0]"), + &a.X0, + &b.X0, condition, - )?; + )?, + X1: conditionally_select_bignat( + cs.namespace(|| "X[1] = cond ? a.X[1] : b.X[1]"), + &a.X1, + &b.X1, + condition, + )?, + }; + Ok(c) +} - Ok(AllocatedRelaxedR1CSInstance { W, E, u, X0, X1 }) - } +/// c = cond ? a: b, where a, b: `Vec` +pub fn conditionally_select_vec_allocated_relaxed_r1cs_instance< + E: Engine, + CS: ConstraintSystem<::Base>, +>( + mut cs: CS, + a: &[AllocatedRelaxedR1CSInstance], + b: &[AllocatedRelaxedR1CSInstance], + condition: &Boolean, +) -> Result>, SynthesisError> { + a.iter() + .enumerate() + .zip(b.iter()) + .map(|((i, a), b)| { + a.conditionally_select( + cs.namespace(|| format!("cond ? a[{}]: b[{}]", i, i)), + b, + condition, + ) + }) + .collect::>, _>>() +} + + +/// c = cond ? a: b, where a, b: `AllocatedPoint` +pub fn conditionally_select_point::Base>>( + mut cs: CS, + a: &AllocatedPoint, + b: &AllocatedPoint, + condition: &Boolean, +) -> Result, SynthesisError> { + let c = AllocatedPoint { + x: conditionally_select( + cs.namespace(|| "x = cond ? a.x : b.x"), + &a.x, + &b.x, + condition, + )?, + y: conditionally_select( + cs.namespace(|| "y = cond ? a.y : b.y"), + &a.y, + &b.y, + condition, + )?, + is_infinity: conditionally_select( + cs.namespace(|| "is_infinity = cond ? a.is_infinity : b.is_infinity"), + &a.is_infinity, + &b.is_infinity, + condition, + )?, + }; + Ok(c) } diff --git a/src/gadgets/utils.rs b/src/gadgets/utils.rs index b4ee99f25..b8f242faf 100644 --- a/src/gadgets/utils.rs +++ b/src/gadgets/utils.rs @@ -148,15 +148,15 @@ pub fn alloc_num_equals>( let r = AllocatedBit::alloc(cs.namespace(|| "r"), r_value)?; - // Allocate t s.t. t=1 if z1 == z2 else 1/(z1 - z2) + // Allocate t s.t. t=1 if a == b else 1/(a - b) let t = AllocatedNum::alloc(cs.namespace(|| "t"), || { - Ok(if *a.get_value().get()? == *b.get_value().get()? { + let a_val = *a.get_value().get()?; + let b_val = *b.get_value().get()?; + Ok(if a_val == b_val { F::ONE } else { - (*a.get_value().get()? - *b.get_value().get()?) - .invert() - .unwrap() + (a_val - b_val).invert().unwrap() }) })?; diff --git a/src/lib.rs b/src/lib.rs index baa09b823..5fb0f292f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,26 +13,27 @@ // private modules mod bellpepper; mod circuit; -mod constants; mod digest; mod nifs; -mod r1cs; - +pub(crate) mod constants; +pub(crate) mod r1cs; // public modules pub mod errors; pub mod gadgets; pub mod provider; pub mod spartan; +pub mod supernova; pub mod traits; use once_cell::sync::OnceCell; -use crate::bellpepper::{ - r1cs::{NovaShape, NovaWitness}, - shape_cs::ShapeCS, - solver::SatisfyingAssignment, -}; use crate::digest::{DigestComputer, SimpleDigestible}; +use crate:: + bellpepper::{ + r1cs::{NovaShape, NovaWitness}, + shape_cs::ShapeCS, + solver::SatisfyingAssignment, + }; use bellpepper_core::ConstraintSystem; use circuit::{NovaAugmentedCircuit, NovaAugmentedCircuitInputs, NovaAugmentedCircuitParams}; use constants::{BN_LIMB_WIDTH, BN_N_LIMBS, NUM_FE_WITHOUT_IO_FOR_CRHF, NUM_HASH_BITS}; @@ -52,6 +53,32 @@ use traits::{ AbsorbInROTrait, Engine, ROConstants, ROConstantsCircuit, ROTrait, }; +/// A type that holds parameters for the primary and secondary circuits of Nova and SuperNova +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct CircuitShape { + F_arity: usize, + r1cs_shape: R1CSShape, +} + +impl SimpleDigestible for CircuitShape {} + +impl CircuitShape { + /// Create a new `CircuitShape` + pub fn new(r1cs_shape: R1CSShape, F_arity: usize) -> Self { + Self { + F_arity, + r1cs_shape, + } + } + + /// Return the [CircuitShape]' digest. + pub fn digest(&self) -> E::Scalar { + let dc: DigestComputer<'_, ::Scalar, CircuitShape> = DigestComputer::new(self); + dc.digest().expect("Failure in computing digest") + } +} + /// A type that holds public parameters of Nova #[derive(Serialize, Deserialize)] #[serde(bound = "")] @@ -67,11 +94,11 @@ where ro_consts_primary: ROConstants, ro_consts_circuit_primary: ROConstantsCircuit, ck_primary: CommitmentKey, - r1cs_shape_primary: R1CSShape, + circuit_shape_primary: CircuitShape, ro_consts_secondary: ROConstants, ro_consts_circuit_secondary: ROConstantsCircuit, ck_secondary: CommitmentKey, - r1cs_shape_secondary: R1CSShape, + circuit_shape_secondary: CircuitShape, augmented_circuit_params_primary: NovaAugmentedCircuitParams, augmented_circuit_params_secondary: NovaAugmentedCircuitParams, #[serde(skip, default = "OnceCell::new")] @@ -95,7 +122,7 @@ where C1: StepCircuit, C2: StepCircuit, { - /// Creates a new `PublicParams` for a pair of circuits `C1` and `C2`. + /// Set up builder to create `PublicParams` for a pair of circuits `C1` and `C2`. /// /// # Note /// @@ -168,7 +195,8 @@ where ); let mut cs: ShapeCS = ShapeCS::new(); let _ = circuit_primary.synthesize(&mut cs); - let (r1cs_shape_primary, ck_primary) = cs.r1cs_shape(ck_hint1); + let (r1cs_shape_primary, ck_primary) = cs.r1cs_shape_and_key(ck_hint1); + let circuit_shape_primary = CircuitShape::new(r1cs_shape_primary, F_arity_primary); // Initialize ck for the secondary let circuit_secondary: NovaAugmentedCircuit<'_, E1, C2> = NovaAugmentedCircuit::new( @@ -179,7 +207,8 @@ where ); let mut cs: ShapeCS = ShapeCS::new(); let _ = circuit_secondary.synthesize(&mut cs); - let (r1cs_shape_secondary, ck_secondary) = cs.r1cs_shape(ck_hint2); + let (r1cs_shape_secondary, ck_secondary) = cs.r1cs_shape_and_key(ck_hint2); + let circuit_shape_secondary = CircuitShape::new(r1cs_shape_secondary, F_arity_secondary); PublicParams { F_arity_primary, @@ -187,11 +216,11 @@ where ro_consts_primary, ro_consts_circuit_primary, ck_primary, - r1cs_shape_primary, + circuit_shape_primary, ro_consts_secondary, ro_consts_circuit_secondary, ck_secondary, - r1cs_shape_secondary, + circuit_shape_secondary, augmented_circuit_params_primary, augmented_circuit_params_secondary, digest: OnceCell::new(), @@ -211,16 +240,16 @@ where /// Returns the number of constraints in the primary and secondary circuits pub const fn num_constraints(&self) -> (usize, usize) { ( - self.r1cs_shape_primary.num_cons, - self.r1cs_shape_secondary.num_cons, + self.circuit_shape_primary.r1cs_shape.num_cons, + self.circuit_shape_secondary.r1cs_shape.num_cons, ) } /// Returns the number of variables in the primary and secondary circuits pub const fn num_variables(&self) -> (usize, usize) { ( - self.r1cs_shape_primary.num_vars, - self.r1cs_shape_secondary.num_vars, + self.circuit_shape_primary.r1cs_shape.num_vars, + self.circuit_shape_secondary.r1cs_shape.num_vars, ) } } @@ -268,6 +297,9 @@ where return Err(NovaError::InvalidInitialInputLength); } + let r1cs_primary = &pp.circuit_shape_primary.r1cs_shape; + let r1cs_secondary = &pp.circuit_shape_secondary.r1cs_shape; + // base case for the primary let mut cs_primary = SatisfyingAssignment::::new(); let inputs_primary: NovaAugmentedCircuitInputs = NovaAugmentedCircuitInputs::new( @@ -291,7 +323,7 @@ where .map_err(|_| NovaError::SynthesisError) .expect("Nova error synthesis"); let (u_primary, w_primary) = cs_primary - .r1cs_instance_and_witness(&pp.r1cs_shape_primary, &pp.ck_primary) + .r1cs_instance_and_witness(r1cs_primary, &pp.ck_primary) .map_err(|_e| NovaError::UnSat) .expect("Nova error unsat"); @@ -317,23 +349,25 @@ where .map_err(|_| NovaError::SynthesisError) .expect("Nova error synthesis"); let (u_secondary, w_secondary) = cs_secondary - .r1cs_instance_and_witness(&pp.r1cs_shape_secondary, &pp.ck_secondary) + .r1cs_instance_and_witness(&pp.circuit_shape_secondary.r1cs_shape, &pp.ck_secondary) .map_err(|_e| NovaError::UnSat) .expect("Nova error unsat"); // IVC proof for the primary circuit let l_w_primary = w_primary; let l_u_primary = u_primary; - let r_W_primary = RelaxedR1CSWitness::from_r1cs_witness(&pp.r1cs_shape_primary, &l_w_primary); - let r_U_primary = - RelaxedR1CSInstance::from_r1cs_instance(&pp.ck_primary, &pp.r1cs_shape_primary, &l_u_primary); + let r_W_primary = RelaxedR1CSWitness::from_r1cs_witness(r1cs_primary, &l_w_primary); + let r_U_primary = RelaxedR1CSInstance::from_r1cs_instance( + &pp.ck_primary, + &pp.circuit_shape_primary.r1cs_shape, + &l_u_primary, + ); // IVC proof for the secondary circuit let l_w_secondary = w_secondary; let l_u_secondary = u_secondary; - let r_W_secondary = RelaxedR1CSWitness::::default(&pp.r1cs_shape_secondary); - let r_U_secondary = - RelaxedR1CSInstance::::default(&pp.ck_secondary, &pp.r1cs_shape_secondary); + let r_W_secondary = RelaxedR1CSWitness::::default(r1cs_secondary); + let r_U_secondary = RelaxedR1CSInstance::::default(&pp.ck_secondary, r1cs_secondary); assert!( !(zi_primary.len() != pp.F_arity_primary || zi_secondary.len() != pp.F_arity_secondary), @@ -387,7 +421,7 @@ where &pp.ck_secondary, &pp.ro_consts_secondary, &scalar_as_base::(pp.digest()), - &pp.r1cs_shape_secondary, + &pp.circuit_shape_secondary.r1cs_shape, &self.r_U_secondary, &self.r_W_secondary, &self.l_u_secondary, @@ -417,7 +451,7 @@ where .map_err(|_| NovaError::SynthesisError)?; let (l_u_primary, l_w_primary) = cs_primary - .r1cs_instance_and_witness(&pp.r1cs_shape_primary, &pp.ck_primary) + .r1cs_instance_and_witness(&pp.circuit_shape_primary.r1cs_shape, &pp.ck_primary) .map_err(|_e| NovaError::UnSat) .expect("Nova error unsat"); @@ -426,7 +460,7 @@ where &pp.ck_primary, &pp.ro_consts_primary, &pp.digest(), - &pp.r1cs_shape_primary, + &pp.circuit_shape_primary.r1cs_shape, &self.r_U_primary, &self.r_W_primary, &l_u_primary, @@ -456,7 +490,7 @@ where .map_err(|_| NovaError::SynthesisError)?; let (l_u_secondary, l_w_secondary) = cs_secondary - .r1cs_instance_and_witness(&pp.r1cs_shape_secondary, &pp.ck_secondary) + .r1cs_instance_and_witness(&pp.circuit_shape_secondary.r1cs_shape, &pp.ck_secondary) .map_err(|_e| NovaError::UnSat)?; // update the running instances and witnesses @@ -558,20 +592,23 @@ where // check the satisfiability of the provided instances let (res_r_primary, (res_r_secondary, res_l_secondary)) = rayon::join( || { - pp.r1cs_shape_primary - .is_sat_relaxed(&pp.ck_primary, &self.r_U_primary, &self.r_W_primary) + pp.circuit_shape_primary.r1cs_shape.is_sat_relaxed( + &pp.ck_primary, + &self.r_U_primary, + &self.r_W_primary, + ) }, || { rayon::join( || { - pp.r1cs_shape_secondary.is_sat_relaxed( + pp.circuit_shape_secondary.r1cs_shape.is_sat_relaxed( &pp.ck_secondary, &self.r_U_secondary, &self.r_W_secondary, ) }, || { - pp.r1cs_shape_secondary.is_sat( + pp.circuit_shape_secondary.r1cs_shape.is_sat( &pp.ck_secondary, &self.l_u_secondary, &self.l_w_secondary, @@ -674,8 +711,9 @@ where ), NovaError, > { - let (pk_primary, vk_primary) = S1::setup(&pp.ck_primary, &pp.r1cs_shape_primary)?; - let (pk_secondary, vk_secondary) = S2::setup(&pp.ck_secondary, &pp.r1cs_shape_secondary)?; + let (pk_primary, vk_primary) = S1::setup(&pp.ck_primary, &pp.circuit_shape_primary.r1cs_shape)?; + let (pk_secondary, vk_secondary) = + S2::setup(&pp.ck_secondary, &pp.circuit_shape_secondary.r1cs_shape)?; let pk = ProverKey { pk_primary, @@ -708,7 +746,7 @@ where &pp.ck_secondary, &pp.ro_consts_secondary, &scalar_as_base::(pp.digest()), - &pp.r1cs_shape_secondary, + &pp.circuit_shape_secondary.r1cs_shape, &recursive_snark.r_U_secondary, &recursive_snark.r_W_secondary, &recursive_snark.l_u_secondary, @@ -721,7 +759,7 @@ where S1::prove( &pp.ck_primary, &pk.pk_primary, - &pp.r1cs_shape_primary, + &pp.circuit_shape_primary.r1cs_shape, &recursive_snark.r_U_primary, &recursive_snark.r_W_primary, ) @@ -730,7 +768,7 @@ where S2::prove( &pp.ck_secondary, &pk.pk_secondary, - &pp.r1cs_shape_secondary, + &pp.circuit_shape_secondary.r1cs_shape, &f_U_secondary, &f_W_secondary, ) @@ -956,13 +994,13 @@ mod tests { test_pp_digest_with::( &trivial_circuit1, &trivial_circuit2, - "cb581e2d5c4b2ef2ddbe2d6849e0da810352f59bcdaca51476dcf9e16072f100", + "f4a04841515b4721519e2671b7ee11e58e2d4a30bb183ded963b71ad2ec80d00", ); test_pp_digest_with::( &cubic_circuit1, &trivial_circuit2, - "3cc29bb864910463e0501bac84cdefc1d4327e9c2ef5b0fd6d45ad1741f1a401", + "dc1b7c40ab50c5c6877ad027769452870cc28f1d13f140de7ca3a00138c58502", ); let trivial_circuit1_grumpkin = TrivialCircuit::<::Scalar>::default(); @@ -973,25 +1011,25 @@ mod tests { test_pp_digest_with::( &trivial_circuit1_grumpkin, &trivial_circuit2_grumpkin, - "c4ecd363a6c1473de7e0d24fc1dbb660f563556e2e13fb4614acdff04cab7701", + "df834cb2bb401251473de2f3bbe6f6c28f1c6848f74dd8281faef91b73e7f400", ); #[cfg(feature = "asm")] test_pp_digest_with::( &cubic_circuit1_grumpkin, &trivial_circuit2_grumpkin, - "4853a6463b6309f6ae76442934d0a423f51f1e10abaddd0d39bf5644ed589100", + "97c009974d5b12b318045d623fff145006fab5f4234cb50e867d66377e191300", ); #[cfg(not(feature = "asm"))] test_pp_digest_with::( &trivial_circuit1_grumpkin, &trivial_circuit2_grumpkin, - "c26cc841d42c19bf98bc2482e66cd30903922f2a923927b85d66f375a821f101", + "c565748bea3336f07c8ff997c542ed62385ff5662f29402c4f9747153f699e01", ); #[cfg(not(feature = "asm"))] test_pp_digest_with::( &cubic_circuit1_grumpkin, &trivial_circuit2_grumpkin, - "4c484cab71e93dda69b420beb7276af969c2034a7ffb0ea8e6964e96a7e5a901", + "5aec6defcb0f6b2bb14aec70362419388916d7a5bc528c0b3fabb197ae57cb03", ); let trivial_circuit1_secp = TrivialCircuit::<::Scalar>::default(); @@ -1001,12 +1039,12 @@ mod tests { test_pp_digest_with::( &trivial_circuit1_secp, &trivial_circuit2_secp, - "b794d655fb39891eaf530ca3be1ec2a5ac97f72a0d07c45dbb84529d8a611502", + "66c6d3618bb824bcb9253b7731247b89432853bf2014ffae45a8f6b00befe303", ); test_pp_digest_with::( &cubic_circuit1_secp, &trivial_circuit2_secp, - "50e6acf363c31c2ac1c9c646b4494cb21aae6cb648c7b0d4c95015c811fba302", + "cc22c270460e11d190235fbd691bdeec51e8200219e5e65112e48bb80b610803", ); } @@ -1030,7 +1068,6 @@ mod tests { &*default_ck_hint(), &*default_ck_hint(), ); - let num_steps = 1; // produce a recursive SNARK diff --git a/src/nifs.rs b/src/nifs.rs index 881bfd126..34c947648 100644 --- a/src/nifs.rs +++ b/src/nifs.rs @@ -120,7 +120,7 @@ mod tests { test_shape_cs::TestShapeCS, }, provider::{Bn256Engine, PallasEngine, Secp256k1Engine}, - r1cs::{SparseMatrix, R1CS}, + r1cs::{commitment_key, SparseMatrix}, traits::{snark::default_ck_hint, Engine}, }; use ::bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; @@ -164,7 +164,7 @@ mod tests { // First create the shape let mut cs: TestShapeCS = TestShapeCS::new(); let _ = synthesize_tiny_r1cs_bellpepper(&mut cs, None); - let (shape, ck) = cs.r1cs_shape(&*default_ck_hint()); + let (shape, ck) = cs.r1cs_shape_and_key(&*default_ck_hint()); let ro_consts = <::RO as ROTrait<::Base, ::Scalar>>::Constants::default(); @@ -323,7 +323,7 @@ mod tests { }; // generate generators and ro constants - let ck = R1CS::::commitment_key(&S, &*default_ck_hint()); + let ck = commitment_key(&S, &*default_ck_hint()); let ro_consts = <::RO as ROTrait<::Base, ::Scalar>>::Constants::default(); diff --git a/src/r1cs/mod.rs b/src/r1cs/mod.rs index b55a24b04..b382f4860 100644 --- a/src/r1cs/mod.rs +++ b/src/r1cs/mod.rs @@ -1,7 +1,7 @@ //! This module defines R1CS related types and a folding scheme for Relaxed R1CS mod sparse; #[cfg(test)] -mod util; +pub(crate) mod util; use crate::{ constants::{BN_LIMB_WIDTH, BN_N_LIMBS}, @@ -16,7 +16,7 @@ use crate::{ }, Commitment, CommitmentKey, CE, }; -use core::{cmp::max, marker::PhantomData}; +use core::cmp::max; use ff::Field; use once_cell::sync::OnceCell; @@ -25,13 +25,6 @@ use serde::{Deserialize, Serialize}; pub(crate) use self::sparse::SparseMatrix; -/// Public parameters for a given R1CS -#[derive(Clone, Serialize, Deserialize)] -#[serde(bound = "")] -pub struct R1CS { - _p: PhantomData, -} - /// A type that holds the shape of the R1CS matrices #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct R1CSShape { @@ -78,26 +71,34 @@ pub struct RelaxedR1CSInstance { pub(crate) u: E::Scalar, } +/// A type for functions that hints commitment key sizing by returning the floor of the number of required generators. pub type CommitmentKeyHint = dyn Fn(&R1CSShape) -> usize; -impl R1CS { - /// Generates public parameters for a Rank-1 Constraint System (R1CS). - /// - /// This function takes into consideration the shape of the R1CS matrices and a hint function - /// for the number of generators. It returns a `CommitmentKey`. - /// - /// # Arguments - /// - /// * `S`: The shape of the R1CS matrices. - /// * `ck_floor`: A function that provides a floor for the number of generators. A good function - /// to provide is the ck_floor field defined in the trait `RelaxedR1CSSNARKTrait`. - /// - pub fn commitment_key(S: &R1CSShape, ck_floor: &CommitmentKeyHint) -> CommitmentKey { - let num_cons = S.num_cons; - let num_vars = S.num_vars; - let ck_hint = ck_floor(S); - E::CE::setup(b"ck", max(max(num_cons, num_vars), ck_hint)) - } +/// Generates public parameters for a Rank-1 Constraint System (R1CS). +/// +/// This function takes into consideration the shape of the R1CS matrices and a hint function +/// for the number of generators. It returns a `CommitmentKey`. +/// +/// # Arguments +/// +/// * `S`: The shape of the R1CS matrices. +/// * `ck_floor`: A function that provides a floor for the number of generators. A good function to +/// provide is the `commitment_key_floor` field in the trait `RelaxedR1CSSNARKTrait`. +/// +pub fn commitment_key( + S: &R1CSShape, + ck_floor: &CommitmentKeyHint, +) -> CommitmentKey { + let size = commitment_key_size(S, ck_floor); + E::CE::setup(b"ck", size) +} + +/// Computes the number of generators required for the commitment key corresponding to shape `S`. +pub fn commitment_key_size(S: &R1CSShape, ck_floor: &CommitmentKeyHint) -> usize { + let num_cons = S.num_cons; + let num_vars = S.num_vars; + let ck_hint = ck_floor(S); + max(max(num_cons, num_vars), ck_hint) } impl R1CSShape { @@ -166,6 +167,7 @@ impl R1CSShape { cons_valid && vars_valid && io_valid && io_lt_vars } + /// multiplies a vector with the matrix pub fn multiply_vec( &self, z: &[E::Scalar], diff --git a/src/spartan/batched.rs b/src/spartan/batched.rs new file mode 100644 index 000000000..66a0e7c23 --- /dev/null +++ b/src/spartan/batched.rs @@ -0,0 +1,621 @@ +//! This module implements `BatchedRelaxedR1CSSNARKTrait` using Spartan that is generic over the polynomial commitment +//! and evaluation argument (i.e., a PCS) This version of Spartan does not use preprocessing so the verifier keeps the +//! entire description of R1CS matrices. This is essentially optimal for the verifier when using an IPA-based polynomial +//! commitment scheme. This batched implementation batches the outer and inner sumchecks of the Spartan SNARK. + +use ff::Field; +use serde::{Deserialize, Serialize}; + +use itertools::Itertools; +use once_cell::sync::OnceCell; +use rayon::prelude::*; + +use super::{ + compute_eval_table_sparse, + math::Math, + polys::{eq::EqPolynomial, multilinear::MultilinearPolynomial}, + powers, + snark::batch_eval_prove, + sumcheck::SumcheckProof, + PolyEvalInstance, PolyEvalWitness, +}; + +use crate::{ + digest::{DigestComputer, SimpleDigestible}, + errors::NovaError, + r1cs::{R1CSShape, RelaxedR1CSInstance, RelaxedR1CSWitness, SparseMatrix}, + spartan::{ + polys::{multilinear::SparsePolynomial, power::PowPolynomial}, + snark::batch_eval_verify, + }, + traits::{ + evaluation::EvaluationEngineTrait, + snark::{BatchedRelaxedR1CSSNARKTrait, DigestHelperTrait}, + Engine, TranscriptEngineTrait, + }, + zip_with, CommitmentKey, +}; + +/// A succinct proof of knowledge of a witness to a batch of relaxed R1CS instances +/// The proof is produced using Spartan's combination of the sum-check and +/// the commitment to a vector viewed as a polynomial commitment +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct BatchedRelaxedR1CSSNARK> { + sc_proof_outer: SumcheckProof, + // Claims ([Azᵢ(τᵢ)], [Bzᵢ(τᵢ)], [Czᵢ(τᵢ)]) + claims_outer: (Vec, Vec, Vec), + // [Eᵢ(r_x)] + evals_E: Vec, + sc_proof_inner: SumcheckProof, + // [Wᵢ(r_y[1..])] + evals_W: Vec, + sc_proof_batch: SumcheckProof, + // [Wᵢ(r_z), Eᵢ(r_z)] + evals_batch: Vec, + eval_arg: EE::EvaluationArgument, +} + +/// A type that represents the prover's key +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct ProverKey> { + pk_ee: EE::ProverKey, + vk_digest: E::Scalar, // digest of the verifier's key +} + +/// A type that represents the verifier's key +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct VerifierKey> { + vk_ee: EE::VerifierKey, + S: Vec>, + #[serde(skip, default = "OnceCell::new")] + digest: OnceCell, +} + +impl> VerifierKey { + fn new(shapes: Vec>, vk_ee: EE::VerifierKey) -> Self { + VerifierKey { + vk_ee, + S: shapes, + digest: OnceCell::new(), + } + } +} + +impl> SimpleDigestible for VerifierKey {} + +impl> DigestHelperTrait for VerifierKey { + /// Returns the digest of the verifier's key. + fn digest(&self) -> E::Scalar { + self + .digest + .get_or_try_init(|| { + let dc = DigestComputer::::new(self); + dc.digest() + }) + .cloned() + .expect("Failure to retrieve digest!") + } +} + +impl> BatchedRelaxedR1CSSNARKTrait + for BatchedRelaxedR1CSSNARK +{ + type ProverKey = ProverKey; + + type VerifierKey = VerifierKey; + + fn setup( + ck: &CommitmentKey, + S: Vec<&R1CSShape>, + ) -> Result<(Self::ProverKey, Self::VerifierKey), NovaError> { + let (pk_ee, vk_ee) = EE::setup(ck); + + let S = S.iter().map(|s| s.pad()).collect(); + + let vk = VerifierKey::new(S, vk_ee); + + let pk = ProverKey { + pk_ee, + vk_digest: vk.digest(), + }; + + Ok((pk, vk)) + } + + fn prove( + ck: &CommitmentKey, + pk: &Self::ProverKey, + S: Vec<&R1CSShape>, + U: &[RelaxedR1CSInstance], + W: &[RelaxedR1CSWitness], + ) -> Result { + let num_instances = U.len(); + // Pad shapes and ensure their sizes are correct + let S = S + .iter() + .map(|s| { + let s = s.pad(); + if s.is_regular_shape() { + Ok(s) + } else { + Err(NovaError::InternalError) + } + }) + .collect::, _>>()?; + + // Pad (W,E) for each instance + let W = zip_with!(iter, (W, S), |w, s| w.pad(s)).collect::>>(); + + let mut transcript = E::TE::new(b"BatchedRelaxedR1CSSNARK"); + + transcript.absorb(b"vk", &pk.vk_digest); + if num_instances > 1 { + let num_instances_field = E::Scalar::from(num_instances as u64); + transcript.absorb(b"n", &num_instances_field); + } + U.iter().for_each(|u| { + transcript.absorb(b"U", u); + }); + + let (polys_W, polys_E): (Vec<_>, Vec<_>) = W.into_iter().map(|w| (w.W, w.E)).unzip(); + + // Append public inputs to W: Z = [W, u, X] + let polys_Z = zip_with!(iter, (polys_W, U), |w, u| [ + w.clone(), + vec![u.u], + u.X.clone() + ] + .concat()) + .collect::>>(); + + let (num_rounds_x, num_rounds_y): (Vec<_>, Vec<_>) = S + .iter() + .map(|s| (s.num_cons.log_2(), s.num_vars.log_2() + 1)) + .unzip(); + let num_rounds_x_max = *num_rounds_x.iter().max().unwrap(); + let num_rounds_y_max = *num_rounds_y.iter().max().unwrap(); + + // Generate tau polynomial corresponding to eq(τ, τ², τ⁴ , …) + // for a random challenge τ + let tau = transcript.squeeze(b"t")?; + let all_taus = PowPolynomial::squares(&tau, num_rounds_x_max); + + let polys_tau = num_rounds_x + .iter() + .map(|&num_rounds_x| PowPolynomial::evals_with_powers(&all_taus, num_rounds_x)) + .map(MultilinearPolynomial::new) + .collect::>(); + + // Compute MLEs of Az, Bz, Cz, uCz + E + let (polys_Az, polys_Bz, polys_Cz): (Vec<_>, Vec<_>, Vec<_>) = + zip_with!(par_iter, (S, polys_Z), |s, poly_Z| { + let (poly_Az, poly_Bz, poly_Cz) = s.multiply_vec(poly_Z)?; + Ok((poly_Az, poly_Bz, poly_Cz)) + }) + .collect::, NovaError>>()? + .into_iter() + .multiunzip(); + + let polys_uCz_E = zip_with!(par_iter, (U, polys_E, polys_Cz), |u, poly_E, poly_Cz| { + zip_with!(par_iter, (poly_Cz, poly_E), |cz, e| u.u * cz + e).collect::>() + }) + .collect::>(); + + let comb_func_outer = + |poly_A_comp: &E::Scalar, + poly_B_comp: &E::Scalar, + poly_C_comp: &E::Scalar, + poly_D_comp: &E::Scalar| + -> E::Scalar { *poly_A_comp * (*poly_B_comp * *poly_C_comp - *poly_D_comp) }; + + // Sample challenge for random linear-combination of outer claims + let outer_r = transcript.squeeze(b"out_r")?; + let outer_r_powers = powers::(&outer_r, num_instances); + + // Verify outer sumcheck: Az * Bz - uCz_E for each instance + let (sc_proof_outer, r_x, claims_outer) = SumcheckProof::prove_cubic_with_additive_term_batch( + &vec![E::Scalar::ZERO; num_instances], + &num_rounds_x, + polys_tau, + polys_Az + .into_iter() + .map(MultilinearPolynomial::new) + .collect(), + polys_Bz + .into_iter() + .map(MultilinearPolynomial::new) + .collect(), + polys_uCz_E + .into_iter() + .map(MultilinearPolynomial::new) + .collect(), + &outer_r_powers, + comb_func_outer, + &mut transcript, + )?; + + let r_x = num_rounds_x + .iter() + .map(|&num_rounds| r_x[(num_rounds_x_max - num_rounds)..].to_vec()) + .collect::>(); + + // Extract evaluations of Az, Bz from Sumcheck and Cz, E at r_x + let (evals_Az_Bz_Cz, evals_E): (Vec<_>, Vec<_>) = zip_with!( + par_iter, + (claims_outer[1], claims_outer[2], polys_Cz, polys_E, r_x), + |eval_Az, eval_Bz, poly_Cz, poly_E, r_x| { + let (eval_Cz, eval_E) = rayon::join( + || MultilinearPolynomial::evaluate_with(poly_Cz, r_x), + || MultilinearPolynomial::evaluate_with(poly_E, r_x), + ); + ((*eval_Az, *eval_Bz, eval_Cz), eval_E) + } + ) + .unzip(); + + evals_Az_Bz_Cz.iter().zip_eq(evals_E.iter()).for_each( + |(&(eval_Az, eval_Bz, eval_Cz), &eval_E)| { + transcript.absorb( + b"claims_outer", + &[eval_Az, eval_Bz, eval_Cz, eval_E].as_slice(), + ) + }, + ); + + let inner_r = transcript.squeeze(b"in_r")?; + let inner_r_square = inner_r.square(); + let inner_r_cube = inner_r_square * inner_r; + let inner_r_powers = powers::(&inner_r_cube, num_instances); + + let claims_inner_joint = evals_Az_Bz_Cz + .iter() + .map(|(eval_Az, eval_Bz, eval_Cz)| *eval_Az + inner_r * eval_Bz + inner_r_square * eval_Cz) + .collect::>(); + + let polys_ABCs = { + let inner = |M_evals_As: Vec, + M_evals_Bs: Vec, + M_evals_Cs: Vec| + -> Vec { + zip_with!( + into_par_iter, + (M_evals_As, M_evals_Bs, M_evals_Cs), + |eval_A, eval_B, eval_C| eval_A + inner_r * eval_B + inner_r_square * eval_C + ) + .collect::>() + }; + + zip_with!(par_iter, (S, r_x), |s, r_x| { + let evals_rx = EqPolynomial::evals_from_points(r_x); + let (eval_A, eval_B, eval_C) = compute_eval_table_sparse(s, &evals_rx); + MultilinearPolynomial::new(inner(eval_A, eval_B, eval_C)) + }) + .collect::>() + }; + + let polys_Z = polys_Z + .into_iter() + .zip_eq(num_rounds_y.iter()) + .map(|(mut z, &num_rounds_y)| { + z.resize(1 << num_rounds_y, E::Scalar::ZERO); + MultilinearPolynomial::new(z) + }) + .collect::>(); + + let comb_func = |poly_A_comp: &E::Scalar, poly_B_comp: &E::Scalar| -> E::Scalar { + *poly_A_comp * *poly_B_comp + }; + + let (sc_proof_inner, r_y, _claims_inner): (SumcheckProof, Vec, (Vec<_>, Vec<_>)) = + SumcheckProof::prove_quad_batch( + &claims_inner_joint, + &num_rounds_y, + polys_ABCs, + polys_Z, + &inner_r_powers, + comb_func, + &mut transcript, + )?; + + let r_y = num_rounds_y + .iter() + .map(|num_rounds| { + let (_, r_y_hi) = r_y.split_at(num_rounds_y_max - num_rounds); + r_y_hi + }) + .collect::>(); + + let evals_W = zip_with!(par_iter, (polys_W, r_y), |poly, r_y| { + MultilinearPolynomial::evaluate_with(poly, &r_y[1..]) + }) + .collect::>(); + + // Create evaluation instances for W(r_y[1..]) and E(r_x) + let (w_vec, u_vec) = { + let mut w_vec = Vec::with_capacity(2 * num_instances); + let mut u_vec = Vec::with_capacity(2 * num_instances); + w_vec.extend(polys_W.into_iter().map(|poly| PolyEvalWitness { p: poly })); + u_vec.extend(zip_with!(iter, (evals_W, U, r_y), |eval, u, r_y| { + PolyEvalInstance { + c: u.comm_W, + x: r_y[1..].to_vec(), + e: *eval, + } + })); + + w_vec.extend(polys_E.into_iter().map(|poly| PolyEvalWitness { p: poly })); + u_vec.extend(zip_with!( + (evals_E.iter(), U.iter(), r_x), + |eval_E, u, r_x| PolyEvalInstance { + c: u.comm_E, + x: r_x, + e: *eval_E, + } + )); + (w_vec, u_vec) + }; + + let (batched_u, batched_w, sc_proof_batch, claims_batch_left) = + batch_eval_prove(u_vec, w_vec, &mut transcript)?; + + let eval_arg = EE::prove( + ck, + &pk.pk_ee, + &mut transcript, + &batched_u.c, + &batched_w.p, + &batched_u.x, + &batched_u.e, + )?; + + let (evals_Az, evals_Bz, evals_Cz): (Vec<_>, Vec<_>, Vec<_>) = + evals_Az_Bz_Cz.into_iter().multiunzip(); + + Ok(BatchedRelaxedR1CSSNARK { + sc_proof_outer, + claims_outer: (evals_Az, evals_Bz, evals_Cz), + evals_E, + sc_proof_inner, + evals_W, + sc_proof_batch, + evals_batch: claims_batch_left, + eval_arg, + }) + } + + fn verify(&self, vk: &Self::VerifierKey, U: &[RelaxedR1CSInstance]) -> Result<(), NovaError> { + let num_instances = U.len(); + let mut transcript = E::TE::new(b"BatchedRelaxedR1CSSNARK"); + + transcript.absorb(b"vk", &vk.digest()); + if num_instances > 1 { + let num_instances_field = E::Scalar::from(num_instances as u64); + transcript.absorb(b"n", &num_instances_field); + } + U.iter().for_each(|u| { + transcript.absorb(b"U", u); + }); + + let num_instances = U.len(); + + let (num_rounds_x, num_rounds_y): (Vec<_>, Vec<_>) = vk + .S + .iter() + .map(|s| (s.num_cons.log_2(), s.num_vars.log_2() + 1)) + .unzip(); + let num_rounds_x_max = *num_rounds_x.iter().max().unwrap(); + let num_rounds_y_max = *num_rounds_y.iter().max().unwrap(); + + // Define τ polynomials of the appropriate size for each instance + let polys_tau = { + let tau = transcript.squeeze(b"t")?; + + num_rounds_x + .iter() + .map(|&num_rounds| PowPolynomial::new(&tau, num_rounds)) + .collect::>() + }; + + // Sample challenge for random linear-combination of outer claims + let outer_r = transcript.squeeze(b"out_r")?; + let outer_r_powers = powers::(&outer_r, num_instances); + + let (claim_outer_final, r_x) = self.sc_proof_outer.verify_batch( + &vec![E::Scalar::ZERO; num_instances], + &num_rounds_x, + &outer_r_powers, + 3, + &mut transcript, + )?; + + // Since each instance has a different number of rounds, the Sumcheck + // prover skips the first num_rounds_x_max - num_rounds_x rounds. + // The evaluation point for each instance is therefore r_x[num_rounds_x_max - num_rounds_x..] + let r_x = num_rounds_x + .iter() + .map(|num_rounds| r_x[(num_rounds_x_max - num_rounds)..].to_vec()) + .collect::>(); + + // Extract evaluations into a vector [(Azᵢ, Bzᵢ, Czᵢ, Eᵢ)] + // TODO: This is a multizip, simplify + let ABCE_evals = zip_with!( + iter, + ( + self.evals_E, + self.claims_outer.0, + self.claims_outer.1, + self.claims_outer.2 + ), + |eval_E, claim_Az, claim_Bz, claim_Cz| (*claim_Az, *claim_Bz, *claim_Cz, *eval_E) + ) + .collect::>(); + + // Add evaluations of Az, Bz, Cz, E to transcript + ABCE_evals + .iter() + .for_each(|(claim_Az, claim_Bz, claim_Cz, eval_E)| { + transcript.absorb( + b"claims_outer", + &[*claim_Az, *claim_Bz, *claim_Cz, *eval_E].as_slice(), + ) + }); + + // Evaluate τ(rₓ) for each instance + let evals_tau = zip_with!(iter, (polys_tau, r_x), |poly_tau, r_x| poly_tau + .evaluate(r_x)); + + // Compute expected claim for all instances ∑ᵢ rⁱ⋅τ(rₓ)⋅(Azᵢ⋅Bzᵢ − uᵢ⋅Czᵢ − Eᵢ) + let claim_outer_final_expected = zip_with!( + ( + ABCE_evals.iter().copied(), + U.iter(), + evals_tau, + outer_r_powers.iter() + ), + |ABCE_eval, u, eval_tau, r| { + let (claim_Az, claim_Bz, claim_Cz, eval_E) = ABCE_eval; + *r * eval_tau * (claim_Az * claim_Bz - u.u * claim_Cz - eval_E) + } + ) + .sum::(); + + if claim_outer_final != claim_outer_final_expected { + return Err(NovaError::InvalidSumcheckProof); + } + + let inner_r = transcript.squeeze(b"in_r")?; + let inner_r_square = inner_r.square(); + let inner_r_cube = inner_r_square * inner_r; + let inner_r_powers = powers::(&inner_r_cube, num_instances); + + // Compute inner claims Mzᵢ = (Azᵢ + r⋅Bzᵢ + r²⋅Czᵢ), + // which are batched by Sumcheck into one claim: ∑ᵢ r³ⁱ⋅Mzᵢ + let claims_inner = ABCE_evals + .into_iter() + .map(|(claim_Az, claim_Bz, claim_Cz, _)| { + claim_Az + inner_r * claim_Bz + inner_r_square * claim_Cz + }) + .collect::>(); + + let (claim_inner_final, r_y) = self.sc_proof_inner.verify_batch( + &claims_inner, + &num_rounds_y, + &inner_r_powers, + 2, + &mut transcript, + )?; + let r_y: Vec> = num_rounds_y + .iter() + .map(|num_rounds| r_y[(num_rounds_y_max - num_rounds)..].to_vec()) + .collect(); + + // Compute evaluations of Zᵢ = [Wᵢ, uᵢ, Xᵢ] at r_y + // Zᵢ(r_y) = (1−r_y[0])⋅W(r_y[1..]) + r_y[0]⋅MLE([uᵢ, Xᵢ])(r_y[1..]) + let evals_Z = zip_with!(iter, (self.evals_W, U, r_y), |eval_W, U, r_y| { + let eval_X = { + // constant term + let mut poly_X = vec![(0, U.u)]; + //remaining inputs + poly_X.extend( + U.X + .iter() + .enumerate() + .map(|(i, x_i)| (i + 1, *x_i)) + .collect::>(), + ); + SparsePolynomial::new(r_y.len() - 1, poly_X).evaluate(&r_y[1..]) + }; + (E::Scalar::ONE - r_y[0]) * eval_W + r_y[0] * eval_X + }) + .collect::>(); + + // compute evaluations of R1CS matrices M(r_x, r_y) = eq(r_y)ᵀ⋅M⋅eq(r_x) + let multi_evaluate = |M_vec: &[&SparseMatrix], + r_x: &[E::Scalar], + r_y: &[E::Scalar]| + -> Vec { + let evaluate_with_table = + // TODO(@winston-h-zhang): review + |M: &SparseMatrix, T_x: &[E::Scalar], T_y: &[E::Scalar]| -> E::Scalar { + M.indptr + .par_windows(2) + .enumerate() + .map(|(row_idx, ptrs)| { + M.get_row_unchecked(ptrs.try_into().unwrap()) + .map(|(val, col_idx)| T_x[row_idx] * T_y[*col_idx] * val) + .sum::() + }) + .sum() + }; + + let (T_x, T_y) = rayon::join( + || EqPolynomial::evals_from_points(r_x), + || EqPolynomial::evals_from_points(r_y), + ); + + M_vec + .par_iter() + .map(|&M_vec| evaluate_with_table(M_vec, &T_x, &T_y)) + .collect() + }; + + // Compute inner claim ∑ᵢ r³ⁱ⋅(Aᵢ(r_x, r_y) + r⋅Bᵢ(r_x, r_y) + r²⋅Cᵢ(r_x, r_y))⋅Zᵢ(r_y) + let claim_inner_final_expected = zip_with!( + iter, + (vk.S, r_x, r_y, evals_Z, inner_r_powers), + |S, r_x, r_y, eval_Z, r_i| { + let evals = multi_evaluate(&[&S.A, &S.B, &S.C], r_x, r_y); + let eval = evals[0] + inner_r * evals[1] + inner_r_square * evals[2]; + eval * r_i * eval_Z + } + ) + .sum::(); + + if claim_inner_final != claim_inner_final_expected { + return Err(NovaError::InvalidSumcheckProof); + } + + // Create evaluation instances for W(r_y[1..]) and E(r_x) + let u_vec = { + let mut u_vec = Vec::with_capacity(2 * num_instances); + u_vec.extend(zip_with!(iter, (self.evals_W, U, r_y), |eval, u, r_y| { + PolyEvalInstance { + c: u.comm_W, + x: r_y[1..].to_vec(), + e: *eval, + } + })); + + u_vec.extend(zip_with!(iter, (self.evals_E, U, r_x), |eval, u, r_x| { + PolyEvalInstance { + c: u.comm_E, + x: r_x.to_vec(), + e: *eval, + } + })); + u_vec + }; + + let batched_u = batch_eval_verify( + u_vec, + &mut transcript, + &self.sc_proof_batch, + &self.evals_batch, + )?; + + // verify + EE::verify( + &vk.vk_ee, + &mut transcript, + &batched_u.c, + &batched_u.x, + &batched_u.e, + &self.eval_arg, + )?; + + Ok(()) + } +} diff --git a/src/spartan/batched_ppsnark.rs b/src/spartan/batched_ppsnark.rs new file mode 100644 index 000000000..71df552bf --- /dev/null +++ b/src/spartan/batched_ppsnark.rs @@ -0,0 +1,1352 @@ +//! batched pp snark +//! +//! + +use crate::{ + digest::{DigestComputer, SimpleDigestible}, + errors::NovaError, + r1cs::{R1CSShape, RelaxedR1CSInstance, RelaxedR1CSWitness}, + spartan::{ + math::Math, + polys::{ + eq::EqPolynomial, + identity::IdentityPolynomial, + masked_eq::MaskedEqPolynomial, + multilinear::MultilinearPolynomial, + multilinear::SparsePolynomial, + power::PowPolynomial, + univariate::{CompressedUniPoly, UniPoly}, + }, + powers, + ppsnark::{ + InnerSumcheckInstance, MemorySumcheckInstance, OuterSumcheckInstance, + R1CSShapeSparkCommitment, R1CSShapeSparkRepr, SumcheckEngine, WitnessBoundSumcheck, + }, + sumcheck::SumcheckProof, + PolyEvalInstance, PolyEvalWitness, + }, + traits::{ + commitment::{CommitmentEngineTrait, CommitmentTrait, Len}, + evaluation::EvaluationEngineTrait, + snark::{BatchedRelaxedR1CSSNARKTrait, DigestHelperTrait}, + Engine, TranscriptEngineTrait, + }, + zip_with, zip_with_for_each, Commitment, CommitmentKey, CompressedCommitment, +}; +use ff::Field; +use itertools::{chain, Itertools as _}; +use once_cell::sync::*; +use rayon::prelude::*; +use serde::{Deserialize, Serialize}; + +/// A type that represents the prover's key +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct ProverKey> { + pk_ee: EE::ProverKey, + S_repr: Vec>, + S_comm: Vec>, + vk_digest: E::Scalar, // digest of verifier's key +} + +/// A type that represents the verifier's key +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct VerifierKey> { + vk_ee: EE::VerifierKey, + S_comm: Vec>, + num_vars: Vec, + #[serde(skip, default = "OnceCell::new")] + digest: OnceCell, +} +impl> VerifierKey { + fn new( + num_vars: Vec, + S_comm: Vec>, + vk_ee: EE::VerifierKey, + ) -> Self { + VerifierKey { + num_vars, + S_comm, + vk_ee, + digest: Default::default(), + } + } +} +impl> SimpleDigestible for VerifierKey {} + +impl> DigestHelperTrait for VerifierKey { + /// Returns the digest of the verifier's key + fn digest(&self) -> E::Scalar { + self + .digest + .get_or_try_init(|| { + let dc = DigestComputer::new(self); + dc.digest() + }) + .cloned() + .expect("Failure to retrieve digest!") + } +} + +/// A succinct proof of knowledge of a witness to a relaxed R1CS instance +/// The proof is produced using Spartan's combination of the sum-check and +/// the commitment to a vector viewed as a polynomial commitment +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct BatchedRelaxedR1CSSNARK> { + // commitment to oracles: the first three are for Az, Bz, Cz, + // and the last two are for memory reads + comms_Az_Bz_Cz: Vec<[CompressedCommitment; 3]>, + comms_L_row_col: Vec<[CompressedCommitment; 2]>, + // commitments to aid the memory checks + // [t_plus_r_inv_row, w_plus_r_inv_row, t_plus_r_inv_col, w_plus_r_inv_col] + comms_mem_oracles: Vec<[CompressedCommitment; 4]>, + + // claims about Az, Bz, and Cz polynomials + evals_Az_Bz_Cz_at_tau: Vec<[E::Scalar; 3]>, + + // sum-check + sc: SumcheckProof, + + // claims from the end of sum-check + evals_Az_Bz_Cz_W_E: Vec<[E::Scalar; 5]>, + evals_L_row_col: Vec<[E::Scalar; 2]>, + // [t_plus_r_inv_row, w_plus_r_inv_row, t_plus_r_inv_col, w_plus_r_inv_col] + evals_mem_oracle: Vec<[E::Scalar; 4]>, + // [val_A, val_B, val_C, row, col, ts_row, ts_col] + evals_mem_preprocessed: Vec<[E::Scalar; 7]>, + + // a PCS evaluation argument + eval_arg: EE::EvaluationArgument, +} + +impl> BatchedRelaxedR1CSSNARKTrait + for BatchedRelaxedR1CSSNARK +{ + type ProverKey = ProverKey; + type VerifierKey = VerifierKey; + + fn ck_floor() -> Box Fn(&'a R1CSShape) -> usize> { + Box::new(|shape: &R1CSShape| -> usize { + // the commitment key should be large enough to commit to the R1CS matrices + std::cmp::max( + shape.A.len() + shape.B.len() + shape.C.len(), + std::cmp::max(shape.num_cons, 2 * shape.num_vars), + ) + }) + } + + fn setup( + ck: &CommitmentKey, + S: Vec<&R1CSShape>, + ) -> Result<(Self::ProverKey, Self::VerifierKey), NovaError> { + for s in S.iter() { + // check the provided commitment key meets minimal requirements + if ck.length() < Self::ck_floor()(s) { + // return Err(NovaError::InvalidCommitmentKeyLength); + return Err(NovaError::InternalError); + } + } + let (pk_ee, vk_ee) = EE::setup(ck); + + let S = S.iter().map(|s| s.pad()).collect::>(); + let S_repr = S.iter().map(R1CSShapeSparkRepr::new).collect::>(); + let S_comm = S_repr + .iter() + .map(|s_repr| s_repr.commit(ck)) + .collect::>(); + let num_vars = S.iter().map(|s| s.num_vars).collect::>(); + let vk = VerifierKey::new(num_vars, S_comm.clone(), vk_ee); + let pk = ProverKey { + pk_ee, + S_repr, + S_comm, + vk_digest: vk.digest(), + }; + Ok((pk, vk)) + } + + fn prove( + ck: &CommitmentKey, + pk: &Self::ProverKey, + S: Vec<&R1CSShape>, + U: &[RelaxedR1CSInstance], + W: &[RelaxedR1CSWitness], + ) -> Result { + // Pad shapes so that num_vars = num_cons = Nᵢ and check the sizes are correct + let S = S + .par_iter() + .map(|s| { + let s = s.pad(); + if s.is_regular_shape() { + Ok(s) + } else { + Err(NovaError::InternalError) + } + }) + .collect::, _>>()?; + + // N[i] = max(|Aᵢ|+|Bᵢ|+|Cᵢ|, 2*num_varsᵢ, num_consᵢ) + let N = pk.S_repr.iter().map(|s| s.N).collect::>(); + assert!(N.iter().all(|&Ni| Ni.is_power_of_two())); + + let num_instances = U.len(); + + // Pad [(Wᵢ,Eᵢ)] to the next power of 2 (not to Ni) + let W = zip_with!(par_iter, (W, S), |w, s| w.pad(s)).collect::>>(); + + // number of rounds of sum-check + let num_rounds_sc = N.iter().max().unwrap().log_2(); + + // Initialize transcript with vk || [Uᵢ] + let mut transcript = E::TE::new(b"BatchedRelaxedR1CSSNARK"); + transcript.absorb(b"vk", &pk.vk_digest); + if num_instances > 1 { + let num_instances_field = E::Scalar::from(num_instances as u64); + transcript.absorb(b"n", &num_instances_field); + } + U.iter().for_each(|u| { + transcript.absorb(b"U", u); + }); + + // Append public inputs to Wᵢ: Zᵢ = [Wᵢ, uᵢ, Xᵢ] + let polys_Z = zip_with!(par_iter, (W, U, N), |W, U, Ni| { + // poly_Z will be resized later, so we preallocate the correct capacity + let mut poly_Z = Vec::with_capacity(*Ni); + poly_Z.extend(W.W.iter().chain([&U.u]).chain(U.X.iter())); + poly_Z + }) + .collect::>>(); + + // Move polys_W and polys_E, as well as U.u out of U + let (comms_W_E, us): (Vec<_>, Vec<_>) = U.iter().map(|U| ([U.comm_W, U.comm_E], U.u)).unzip(); + let (polys_W, polys_E): (Vec<_>, Vec<_>) = W.into_iter().map(|w| (w.W, w.E)).unzip(); + + // Compute [Az, Bz, Cz] + let mut polys_Az_Bz_Cz = zip_with!(par_iter, (polys_Z, S), |z, s| { + let (Az, Bz, Cz) = s.multiply_vec(z)?; + Ok([Az, Bz, Cz]) + }) + .collect::, NovaError>>()?; + + // Commit to [Az, Bz, Cz] and add to transcript + let comms_Az_Bz_Cz = polys_Az_Bz_Cz + .par_iter() + .map(|[Az, Bz, Cz]| { + let (comm_Az, (comm_Bz, comm_Cz)) = rayon::join( + || E::CE::commit(ck, Az), + || rayon::join(|| E::CE::commit(ck, Bz), || E::CE::commit(ck, Cz)), + ); + [comm_Az, comm_Bz, comm_Cz] + }) + .collect::>(); + comms_Az_Bz_Cz + .iter() + .for_each(|comms| transcript.absorb(b"c", &comms.as_slice())); + + // Compute eq(tau) for each instance in log2(Ni) variables + let tau = transcript.squeeze(b"t")?; + let (polys_tau, coords_tau): (Vec<_>, Vec<_>) = N + .iter() + .map(|&N_i| { + let log_Ni = N_i.log_2(); + let poly = PowPolynomial::new(&tau, log_Ni); + let evals = poly.evals(); + let coords = poly.coordinates(); + (evals, coords) + }) + .unzip(); + + // Pad [Az, Bz, Cz] to Ni + polys_Az_Bz_Cz + .par_iter_mut() + .zip_eq(N.par_iter()) + .for_each(|(az_bz_cz, &Ni)| { + az_bz_cz + .par_iter_mut() + .for_each(|mz| mz.resize(Ni, E::Scalar::ZERO)) + }); + + // Evaluate and commit to [Az(tau), Bz(tau), Cz(tau)] + let evals_Az_Bz_Cz_at_tau = zip_with!( + par_iter, + (polys_Az_Bz_Cz, coords_tau), + |ABCs, tau_coords| { + let [Az, Bz, Cz] = ABCs; + let (eval_Az, (eval_Bz, eval_Cz)) = rayon::join( + || MultilinearPolynomial::evaluate_with(Az, tau_coords), + || { + rayon::join( + || MultilinearPolynomial::evaluate_with(Bz, tau_coords), + || MultilinearPolynomial::evaluate_with(Cz, tau_coords), + ) + }, + ); + [eval_Az, eval_Bz, eval_Cz] + } + ) + .collect::>(); + + // absorb the claimed evaluations into the transcript + evals_Az_Bz_Cz_at_tau.iter().for_each(|evals| { + transcript.absorb(b"e", &evals.as_slice()); + }); + + // Pad Zᵢ, E to Nᵢ + let polys_Z = polys_Z + .into_par_iter() + .zip_eq(N.par_iter()) + .map(|(mut poly_Z, &Ni)| { + poly_Z.resize(Ni, E::Scalar::ZERO); + poly_Z + }) + .collect::>(); + + // Pad both W,E to have the same size. This is inefficient for W since the second half is empty, + // but it makes it easier to handle the batching at the end. + let polys_E = polys_E + .into_par_iter() + .zip_eq(N.par_iter()) + .map(|(mut poly_E, &Ni)| { + poly_E.resize(Ni, E::Scalar::ZERO); + poly_E + }) + .collect::>(); + + let polys_W = polys_W + .into_par_iter() + .zip_eq(N.par_iter()) + .map(|(mut poly_W, &Ni)| { + poly_W.resize(Ni, E::Scalar::ZERO); + poly_W + }) + .collect::>(); + + // (2) send commitments to the following two oracles + // L_row(i) = eq(tau, row(i)) for all i in [0..Nᵢ] + // L_col(i) = z(col(i)) for all i in [0..Nᵢ] + let polys_L_row_col = zip_with!( + par_iter, + (S, N, polys_Z, polys_tau), + |S, Ni, poly_Z, poly_tau| { + let mut L_row = vec![poly_tau[0]; *Ni]; // we place mem_row[0] since resized row is appended with 0s + let mut L_col = vec![poly_Z[Ni - 1]; *Ni]; // we place mem_col[Ni-1] since resized col is appended with Ni-1 + + for (i, (val_r, val_c)) in S + .A + .iter() + .chain(S.B.iter()) + .chain(S.C.iter()) + .map(|(r, c, _)| (poly_tau[r], poly_Z[c])) + .enumerate() + { + L_row[i] = val_r; + L_col[i] = val_c; + } + + [L_row, L_col] + } + ) + .collect::>(); + + let comms_L_row_col = polys_L_row_col + .par_iter() + .map(|[L_row, L_col]| { + let (comm_L_row, comm_L_col) = + rayon::join(|| E::CE::commit(ck, L_row), || E::CE::commit(ck, L_col)); + [comm_L_row, comm_L_col] + }) + .collect::>(); + + // absorb commitments to L_row and L_col in the transcript + comms_L_row_col.iter().for_each(|comms| { + transcript.absorb(b"e", &comms.as_slice()); + }); + + // For each instance, batch Mz = Az + c*Bz + c^2*Cz + let c = transcript.squeeze(b"c")?; + + let polys_Mz: Vec<_> = polys_Az_Bz_Cz + .par_iter() + .map(|polys_Az_Bz_Cz| { + let poly_vec: Vec<&Vec<_>> = polys_Az_Bz_Cz.iter().collect(); + let w = PolyEvalWitness::::batch(&poly_vec[..], &c); + w.p + }) + .collect(); + + let evals_Mz: Vec<_> = zip_with!( + iter, + (comms_Az_Bz_Cz, evals_Az_Bz_Cz_at_tau), + |comm_Az_Bz_Cz, evals_Az_Bz_Cz_at_tau| { + let u = PolyEvalInstance::::batch( + comm_Az_Bz_Cz.as_slice(), + &[], // ignored by the function + evals_Az_Bz_Cz_at_tau.as_slice(), + &c, + ); + u.e + } + ) + .collect(); + + // we now need to prove three claims for each instance + // (outer) + // 0 = \sum_x poly_tau(x) * (poly_Az(x) * poly_Bz(x) - poly_uCz_E(x)) + // eval_Az_at_tau + c * eval_Bz_at_tau + c^2 * eval_Cz_at_tau = (Az+c*Bz+c^2*Cz)(tau) + // (inner) + // eval_Az_at_tau + c * eval_Bz_at_tau + c^2 * eval_Cz_at_tau = \sum_y L_row(y) * (val_A(y) + c * val_B(y) + c^2 * val_C(y)) * L_col(y) + // (mem) + // L_row(i) = eq(tau, row(i)) + // L_col(i) = z(col(i)) + let outer_sc_inst = zip_with!( + ( + polys_Az_Bz_Cz.par_iter(), + polys_E.par_iter(), + polys_Mz.into_par_iter(), + polys_tau.par_iter(), + evals_Mz.par_iter(), + us.par_iter() + ), + |poly_ABC, poly_E, poly_Mz, poly_tau, eval_Mz, u| { + let [poly_Az, poly_Bz, poly_Cz] = poly_ABC; + let poly_uCz_E = zip_with!(par_iter, (poly_Cz, poly_E), |cz, e| *u * cz + e).collect(); + OuterSumcheckInstance::new( + poly_tau.clone(), + poly_Az.clone(), + poly_Bz.clone(), + poly_uCz_E, + poly_Mz, // Mz = Az + c * Bz + c^2 * Cz + eval_Mz, // eval_Az_at_tau + c * eval_Az_at_tau + c^2 * eval_Cz_at_tau + ) + } + ) + .collect::>(); + + let inner_sc_inst = zip_with!( + par_iter, + (pk.S_repr, evals_Mz, polys_L_row_col), + |s_repr, eval_Mz, poly_L| { + let [poly_L_row, poly_L_col] = poly_L; + let c_square = c.square(); + let val = zip_with!( + par_iter, + (s_repr.val_A, s_repr.val_B, s_repr.val_C), + |v_a, v_b, v_c| *v_a + c * *v_b + c_square * *v_c + ) + .collect::>(); + + InnerSumcheckInstance::new( + *eval_Mz, + MultilinearPolynomial::new(poly_L_row.clone()), + MultilinearPolynomial::new(poly_L_col.clone()), + MultilinearPolynomial::new(val), + ) + } + ) + .collect::>(); + + // a third sum-check instance to prove the read-only memory claim + // we now need to prove that L_row and L_col are well-formed + let (mem_sc_inst, comms_mem_oracles, polys_mem_oracles) = { + let gamma = transcript.squeeze(b"g")?; + let r = transcript.squeeze(b"r")?; + + // We start by computing oracles and auxiliary polynomials to help prove the claim + // oracles correspond to [t_plus_r_inv_row, w_plus_r_inv_row, t_plus_r_inv_col, w_plus_r_inv_col] + let (comms_mem_oracles, polys_mem_oracles, mem_aux) = pk + .S_repr + .iter() + .zip_eq(polys_tau.iter()) + .zip_eq(polys_Z.iter()) + .zip_eq(polys_L_row_col.iter()) + .try_fold( + (Vec::new(), Vec::new(), Vec::new()), + |(mut comms, mut polys, mut aux), (((s_repr, poly_tau), poly_Z), [L_row, L_col])| { + let (comm, poly, a) = MemorySumcheckInstance::::compute_oracles( + ck, + &r, + &gamma, + poly_tau, + &s_repr.row, + L_row, + &s_repr.ts_row, + poly_Z, + &s_repr.col, + L_col, + &s_repr.ts_col, + )?; + + comms.push(comm); + polys.push(poly); + aux.push(a); + + Ok::<_, NovaError>((comms, polys, aux)) + }, + )?; + + // Commit to oracles + comms_mem_oracles.iter().for_each(|comms| { + transcript.absorb(b"l", &comms.as_slice()); + }); + + // Sample new random variable for eq polynomial + let rho = transcript.squeeze(b"r")?; + let N_max = N.iter().max().unwrap(); + let all_rhos = PowPolynomial::squares(&rho, N_max.log_2()); + + let instances = zip_with!( + ( + pk.S_repr.par_iter(), + N.par_iter(), + polys_mem_oracles.par_iter(), + mem_aux.into_par_iter() + ), + |s_repr, Ni, polys_mem_oracles, polys_aux| { + MemorySumcheckInstance::::new( + polys_mem_oracles.clone(), + polys_aux, + PowPolynomial::evals_with_powers(&all_rhos, Ni.log_2()), + s_repr.ts_row.clone(), + s_repr.ts_col.clone(), + ) + } + ) + .collect::>(); + (instances, comms_mem_oracles, polys_mem_oracles) + }; + + let witness_sc_inst = zip_with!(par_iter, (polys_W, S), |poly_W, S| { + WitnessBoundSumcheck::new(tau, poly_W.clone(), S.num_vars) + }) + .collect::>(); + + // Run batched Sumcheck for the 3 claims for all instances. + // Note that the polynomials for claims relating to instance i have size Ni. + let (sc, rand_sc, claims_outer, claims_inner, claims_mem, claims_witness) = Self::prove_helper( + num_rounds_sc, + mem_sc_inst, + outer_sc_inst, + inner_sc_inst, + witness_sc_inst, + &mut transcript, + )?; + + let (evals_Az_Bz_Cz_W_E, evals_L_row_col, evals_mem_oracle, evals_mem_preprocessed) = { + let evals_Az_Bz = claims_outer + .into_iter() + .map(|claims| [claims[0][0], claims[0][1]]) + .collect::>(); + + let evals_L_row_col = claims_inner + .into_iter() + .map(|claims| { + // [L_row, L_col] + [claims[0][0], claims[0][1]] + }) + .collect::>(); + + let (evals_mem_oracle, evals_mem_ts): (Vec<_>, Vec<_>) = claims_mem + .into_iter() + .map(|claims| { + ( + // [t_plus_r_inv_row, w_plus_r_inv_row, t_plus_r_inv_col, w_plus_r_inv_col] + [claims[0][0], claims[0][1], claims[1][0], claims[1][1]], + // [ts_row, ts_col] + [claims[0][2], claims[1][2]], + ) + }) + .unzip(); + + let evals_W = claims_witness + .into_iter() + .map(|claims| claims[0][0]) + .collect::>(); + + let (evals_Cz_E, evals_mem_val_row_col): (Vec<_>, Vec<_>) = zip_with!( + iter, + (polys_Az_Bz_Cz, polys_E, pk.S_repr), + |ABCzs, poly_E, s_repr| { + let [_, _, Cz] = ABCzs; + let log_Ni = s_repr.N.log_2(); + let (_, rand_sc) = rand_sc.split_at(num_rounds_sc - log_Ni); + let rand_sc_evals = EqPolynomial::evals_from_points(rand_sc); + let e = [ + Cz, + poly_E, + &s_repr.val_A, + &s_repr.val_B, + &s_repr.val_C, + &s_repr.row, + &s_repr.col, + ] + .into_iter() + .map(|p| { + // Manually compute evaluation to avoid recomputing rand_sc_evals + zip_with!(par_iter, (p, rand_sc_evals), |p, eq| *p * eq).sum() + }) + .collect::>(); + ([e[0], e[1]], [e[2], e[3], e[4], e[5], e[6]]) + } + ) + .unzip(); + + let evals_Az_Bz_Cz_W_E = zip_with!( + (evals_Az_Bz.into_iter(), evals_Cz_E.into_iter(), evals_W), + |Az_Bz, Cz_E, W| { + let [Az, Bz] = Az_Bz; + let [Cz, E] = Cz_E; + [Az, Bz, Cz, W, E] + } + ) + .collect::>(); + + // [val_A, val_B, val_C, row, col, ts_row, ts_col] + let evals_mem_preprocessed = zip_with!( + (evals_mem_val_row_col.into_iter(), evals_mem_ts), + |eval_mem_val_row_col, eval_mem_ts| { + let [val_A, val_B, val_C, row, col] = eval_mem_val_row_col; + let [ts_row, ts_col] = eval_mem_ts; + [val_A, val_B, val_C, row, col, ts_row, ts_col] + } + ) + .collect::>(); + ( + evals_Az_Bz_Cz_W_E, + evals_L_row_col, + evals_mem_oracle, + evals_mem_preprocessed, + ) + }; + + let evals_vec = zip_with!( + iter, + ( + evals_Az_Bz_Cz_W_E, + evals_L_row_col, + evals_mem_oracle, + evals_mem_preprocessed + ), + |Az_Bz_Cz_W_E, L_row_col, mem_oracles, mem_preprocessed| { + chain![Az_Bz_Cz_W_E, L_row_col, mem_oracles, mem_preprocessed] + .cloned() + .collect::>() + } + ) + .collect::>(); + + let comms_vec = zip_with!( + iter, + ( + comms_Az_Bz_Cz, + comms_W_E, + comms_L_row_col, + comms_mem_oracles, + pk.S_comm + ), + |Az_Bz_Cz, comms_W_E, L_row_col, mem_oracles, S_comm| { + chain![ + Az_Bz_Cz, + comms_W_E, + L_row_col, + mem_oracles, + [ + &S_comm.comm_val_A, + &S_comm.comm_val_B, + &S_comm.comm_val_C, + &S_comm.comm_row, + &S_comm.comm_col, + &S_comm.comm_ts_row, + &S_comm.comm_ts_col, + ] + ] + } + ) + .flatten() + .cloned() + .collect::>(); + + let w_vec = zip_with!( + ( + polys_Az_Bz_Cz.into_iter(), + polys_W.into_iter(), + polys_E.into_iter(), + polys_L_row_col.into_iter(), + polys_mem_oracles.into_iter(), + pk.S_repr.iter() + ), + |Az_Bz_Cz, W, E, L_row_col, mem_oracles, S_repr| { + chain![ + Az_Bz_Cz, + [W, E], + L_row_col, + mem_oracles, + [ + S_repr.val_A.clone(), + S_repr.val_B.clone(), + S_repr.val_C.clone(), + S_repr.row.clone(), + S_repr.col.clone(), + S_repr.ts_row.clone(), + S_repr.ts_col.clone(), + ] + ] + } + ) + .flatten() + .map(|p| PolyEvalWitness:: { p }) + .collect::>(); + + evals_vec.iter().for_each(|evals| { + transcript.absorb(b"e", &evals.as_slice()); // comm_vec is already in the transcript + }); + let evals_vec = evals_vec.into_iter().flatten().collect::>(); + + let c = transcript.squeeze(b"c")?; + + // Compute number of variables for each polynomial + let num_vars_u = w_vec.iter().map(|w| w.p.len().log_2()).collect::>(); + let u_batch = + PolyEvalInstance::::batch_diff_size(&comms_vec, &evals_vec, &num_vars_u, rand_sc, c); + let w_batch = PolyEvalWitness::::batch_diff_size(w_vec, c); + + let eval_arg = EE::prove( + ck, + &pk.pk_ee, + &mut transcript, + &u_batch.c, + &w_batch.p, + &u_batch.x, + &u_batch.e, + )?; + + let comms_Az_Bz_Cz = comms_Az_Bz_Cz + .into_iter() + .map(|comms| comms.map(|comm| comm.compress())) + .collect(); + let comms_L_row_col = comms_L_row_col + .into_iter() + .map(|comms| comms.map(|comm| comm.compress())) + .collect(); + let comms_mem_oracles = comms_mem_oracles + .into_iter() + .map(|comms| comms.map(|comm| comm.compress())) + .collect(); + + Ok(BatchedRelaxedR1CSSNARK { + comms_Az_Bz_Cz, + comms_L_row_col, + comms_mem_oracles, + evals_Az_Bz_Cz_at_tau, + sc, + evals_Az_Bz_Cz_W_E, + evals_L_row_col, + evals_mem_oracle, + evals_mem_preprocessed, + eval_arg, + }) + } + + fn verify(&self, vk: &Self::VerifierKey, U: &[RelaxedR1CSInstance]) -> Result<(), NovaError> { + let num_instances = U.len(); + let num_claims_per_instance = 10; + + // number of rounds of sum-check + let num_rounds = vk.S_comm.iter().map(|s| s.N.log_2()).collect::>(); + let num_rounds_max = *num_rounds.iter().max().unwrap(); + + let mut transcript = E::TE::new(b"BatchedRelaxedR1CSSNARK"); + + transcript.absorb(b"vk", &vk.digest()); + if num_instances > 1 { + let num_instances_field = E::Scalar::from(num_instances as u64); + transcript.absorb(b"n", &num_instances_field); + } + U.iter().for_each(|u| { + transcript.absorb(b"U", u); + }); + + // Decompress commitments + let comms_Az_Bz_Cz = self + .comms_Az_Bz_Cz + .iter() + .map(|comms| { + comms + .iter() + .map(Commitment::::decompress) + .collect::, _>>() + }) + .collect::, _>>()?; + + let comms_L_row_col = self + .comms_L_row_col + .iter() + .map(|comms| { + comms + .iter() + .map(Commitment::::decompress) + .collect::, _>>() + }) + .collect::, _>>()?; + + let comms_mem_oracles = self + .comms_mem_oracles + .iter() + .map(|comms| { + comms + .iter() + .map(Commitment::::decompress) + .collect::, _>>() + }) + .collect::, _>>()?; + + // Add commitments [Az, Bz, Cz] to the transcript + comms_Az_Bz_Cz + .iter() + .for_each(|comms| transcript.absorb(b"c", &comms.as_slice())); + + let tau = transcript.squeeze(b"t")?; + let tau_coords = PowPolynomial::new(&tau, num_rounds_max).coordinates(); + + // absorb the claimed evaluations into the transcript + self.evals_Az_Bz_Cz_at_tau.iter().for_each(|evals| { + transcript.absorb(b"e", &evals.as_slice()); + }); + + // absorb commitments to L_row and L_col in the transcript + comms_L_row_col.iter().for_each(|comms| { + transcript.absorb(b"e", &comms.as_slice()); + }); + + // Batch at tau for each instance + let c = transcript.squeeze(b"c")?; + + // Compute eval_Mz = eval_Az_at_tau + c * eval_Bz_at_tau + c^2 * eval_Cz_at_tau + let evals_Mz: Vec<_> = zip_with!( + iter, + (comms_Az_Bz_Cz, self.evals_Az_Bz_Cz_at_tau), + |comm_Az_Bz_Cz, evals_Az_Bz_Cz_at_tau| { + let u = PolyEvalInstance::::batch( + comm_Az_Bz_Cz.as_slice(), + &tau_coords, + evals_Az_Bz_Cz_at_tau.as_slice(), + &c, + ); + u.e + } + ) + .collect(); + + let gamma = transcript.squeeze(b"g")?; + let r = transcript.squeeze(b"r")?; + + comms_mem_oracles.iter().for_each(|comms| { + transcript.absorb(b"l", &comms.as_slice()); + }); + + let rho = transcript.squeeze(b"r")?; + + let s = transcript.squeeze(b"r")?; + let s_powers = powers::(&s, num_instances * num_claims_per_instance); + + let (claim_sc_final, rand_sc) = { + // Gather all claims into a single vector + let claims = evals_Mz + .iter() + .flat_map(|&eval_Mz| { + let mut claims = vec![E::Scalar::ZERO; num_claims_per_instance]; + claims[7] = eval_Mz; + claims[8] = eval_Mz; + claims.into_iter() + }) + .collect::>(); + + // Number of rounds for each claim + let num_rounds_by_claim = num_rounds + .iter() + .flat_map(|num_rounds_i| vec![*num_rounds_i; num_claims_per_instance].into_iter()) + .collect::>(); + + self + .sc + .verify_batch(&claims, &num_rounds_by_claim, &s_powers, 3, &mut transcript)? + }; + + // Truncated sumcheck randomness for each instance + let rand_sc_i = num_rounds + .iter() + .map(|num_rounds| rand_sc[(num_rounds_max - num_rounds)..].to_vec()) + .collect::>(); + + let claim_sc_final_expected = zip_with!( + ( + vk.num_vars.iter(), + rand_sc_i.iter(), + U.iter(), + self.evals_Az_Bz_Cz_W_E.iter().cloned(), + self.evals_L_row_col.iter().cloned(), + self.evals_mem_oracle.iter().cloned(), + self.evals_mem_preprocessed.iter().cloned() + ), + |num_vars, + rand_sc, + U, + evals_Az_Bz_Cz_W_E, + evals_L_row_col, + eval_mem_oracle, + eval_mem_preprocessed| { + let [Az, Bz, Cz, W, E] = evals_Az_Bz_Cz_W_E; + let [L_row, L_col] = evals_L_row_col; + let [t_plus_r_inv_row, w_plus_r_inv_row, t_plus_r_inv_col, w_plus_r_inv_col] = + eval_mem_oracle; + let [val_A, val_B, val_C, row, col, ts_row, ts_col] = eval_mem_preprocessed; + + let num_rounds_i = rand_sc.len(); + let num_vars_log = num_vars.log_2(); + + let eq_rho = { + let rho_coords = PowPolynomial::new(&rho, num_rounds_i).coordinates(); + EqPolynomial::new(rho_coords).evaluate(rand_sc) + }; + + let (eq_tau, eq_masked_tau) = { + let tau_coords = PowPolynomial::new(&tau, num_rounds_i).coordinates(); + let eq_tau = EqPolynomial::new(tau_coords); + + let eq_tau_at_rand = eq_tau.evaluate(rand_sc); + let eq_masked_tau = MaskedEqPolynomial::new(&eq_tau, num_vars_log).evaluate(rand_sc); + + (eq_tau_at_rand, eq_masked_tau) + }; + + // Evaluate identity polynomial + let id = IdentityPolynomial::new(num_rounds_i).evaluate(rand_sc); + + let Z = { + // rand_sc was padded, so we now remove the padding + let (factor, rand_sc_unpad) = { + let l = num_rounds_i - (num_vars_log + 1); + + let (rand_sc_lo, rand_sc_hi) = rand_sc.split_at(l); + + let factor = rand_sc_lo + .iter() + .fold(E::Scalar::ONE, |acc, r_p| acc * (E::Scalar::ONE - r_p)); + + (factor, rand_sc_hi) + }; + + let X = { + // constant term + let mut poly_X = vec![(0, U.u)]; + //remaining inputs + poly_X.extend( + (0..U.X.len()) + .map(|i| (i + 1, U.X[i])) + .collect::>(), + ); + SparsePolynomial::new(num_vars_log, poly_X).evaluate(&rand_sc_unpad[1..]) + }; + + // W was evaluated as if it was padded to logNi variables, + // so we don't multiply it by (1-rand_sc_unpad[0]) + W + factor * rand_sc_unpad[0] * X + }; + + let t_plus_r_row = { + let addr_row = id; + let val_row = eq_tau; + let t = addr_row + gamma * val_row; + t + r + }; + + let w_plus_r_row = { + let addr_row = row; + let val_row = L_row; + let w = addr_row + gamma * val_row; + w + r + }; + + let t_plus_r_col = { + let addr_col = id; + let val_col = Z; + let t = addr_col + gamma * val_col; + t + r + }; + + let w_plus_r_col = { + let addr_col = col; + let val_col = L_col; + let w = addr_col + gamma * val_col; + w + r + }; + + let claims_mem = [ + t_plus_r_inv_row - w_plus_r_inv_row, + t_plus_r_inv_col - w_plus_r_inv_col, + eq_rho * (t_plus_r_inv_row * t_plus_r_row - ts_row), + eq_rho * (w_plus_r_inv_row * w_plus_r_row - E::Scalar::ONE), + eq_rho * (t_plus_r_inv_col * t_plus_r_col - ts_col), + eq_rho * (w_plus_r_inv_col * w_plus_r_col - E::Scalar::ONE), + ]; + + let claims_outer = [ + eq_tau * (Az * Bz - U.u * Cz - E), + eq_tau * (Az + c * Bz + c * c * Cz), + ]; + let claims_inner = [L_row * L_col * (val_A + c * val_B + c * c * val_C)]; + + let claims_witness = [eq_masked_tau * W]; + + chain![claims_mem, claims_outer, claims_inner, claims_witness] + } + ) + .flatten() + .zip_eq(s_powers) + .fold(E::Scalar::ZERO, |acc, (claim, s)| acc + s * claim); + + if claim_sc_final_expected != claim_sc_final { + return Err(NovaError::InvalidSumcheckProof); + } + + let evals_vec = zip_with!( + iter, + ( + self.evals_Az_Bz_Cz_W_E, + self.evals_L_row_col, + self.evals_mem_oracle, + self.evals_mem_preprocessed + ), + |Az_Bz_Cz_W_E, L_row_col, mem_oracles, mem_preprocessed| { + chain![Az_Bz_Cz_W_E, L_row_col, mem_oracles, mem_preprocessed] + .cloned() + .collect::>() + } + ) + .collect::>(); + + // Add all Sumcheck evaluations to the transcript + evals_vec.iter().for_each(|evals| { + transcript.absorb(b"e", &evals.as_slice()); // comm_vec is already in the transcript + }); + + let c = transcript.squeeze(b"c")?; + + // Compute batched polynomial evaluation instance at rand_sc + let u = { + let num_evals = evals_vec[0].len(); + + let evals_vec = evals_vec.into_iter().flatten().collect::>(); + + let num_vars = num_rounds + .iter() + .flat_map(|num_rounds| vec![*num_rounds; num_evals].into_iter()) + .collect::>(); + + let comms_vec = zip_with!( + ( + comms_Az_Bz_Cz.into_iter(), + U.iter(), + comms_L_row_col.into_iter(), + comms_mem_oracles.into_iter(), + vk.S_comm.iter() + ), + |Az_Bz_Cz, U, L_row_col, mem_oracles, S_comm| { + chain![ + Az_Bz_Cz, + [U.comm_W, U.comm_E], + L_row_col, + mem_oracles, + [ + S_comm.comm_val_A, + S_comm.comm_val_B, + S_comm.comm_val_C, + S_comm.comm_row, + S_comm.comm_col, + S_comm.comm_ts_row, + S_comm.comm_ts_col, + ] + ] + } + ) + .flatten() + .collect::>(); + + PolyEvalInstance::::batch_diff_size(&comms_vec, &evals_vec, &num_vars, rand_sc, c) + }; + + // verify + EE::verify(&vk.vk_ee, &mut transcript, &u.c, &u.x, &u.e, &self.eval_arg)?; + + Ok(()) + } +} + +impl> BatchedRelaxedR1CSSNARK +{ + /// Runs the batched Sumcheck protocol for the claims of multiple instance of possibly different sizes. + /// + /// # Details + /// + /// In order to avoid padding all polynomials to the same maximum size, we adopt the following strategy. + /// + /// Let n be the number of variables for the largest instance, + /// and let m be the number of variables for a shorter one. + /// Let P(X_{0},...,X_{m-1}) be one of the MLEs of the short instance, which has been committed to + /// by taking the MSM of its evaluations with the first 2^m basis points of the commitment key. + /// + /// This Sumcheck prover will interpret it as the polynomial + /// P'(X_{0},...,X_{n-1}) = P(X_{n-m},...,X_{n-1}), + /// whose MLE evaluations over {0,1}^m is equal to 2^{n-m} repetitions of the evaluations of P. + /// + /// In order to account for these "imagined" repetitions, the initial claims for this short instances + /// are scaled by 2^{n-m}. + /// + /// For the first n-m rounds, the univariate polynomials relating to this shorter claim will be constant, + /// and equal to the initial claims, scaled by 2^{n-m-i}, where i is the round number. + /// By definition, P' does not depend on X_i, so binding P' to r_i has no effect on the evaluations. + /// The Sumcheck prover will then interpret the polynomial P' as having half as many repetitions + /// in the next round. + /// + /// When we get to round n-m, the Sumcheck proceeds as usual since the polynomials are the expected size + /// for the round. + /// + /// Note that at the end of the protocol, the prover returns the evaluation + /// u' = P'(r_{0},...,r_{n-1}) = P(r_{n-m},...,r_{n-1}) + /// However, the polynomial we actually committed to over {0,1}^n is + /// P''(X_{0},...,X_{n-1}) = L_0(X_{0},...,X_{n-m-1}) * P(X_{n-m},...,X_{n-1}) + /// The SNARK prover/verifier will need to rescale the evaluation by the first Lagrange polynomial + /// u'' = L_0(r_{0},...,r_{n-m-1}) * u' + /// in order batch all evaluations with a single PCS call. + fn prove_helper( + num_rounds: usize, + mut mem: Vec, + mut outer: Vec, + mut inner: Vec, + mut witness: Vec, + transcript: &mut E::TE, + ) -> Result< + ( + SumcheckProof, + Vec, + Vec>>, + Vec>>, + Vec>>, + Vec>>, + ), + NovaError, + > + where + T1: SumcheckEngine, + T2: SumcheckEngine, + T3: SumcheckEngine, + T4: SumcheckEngine, + { + // sanity checks + let num_instances = mem.len(); + assert_eq!(outer.len(), num_instances); + assert_eq!(inner.len(), num_instances); + assert_eq!(witness.len(), num_instances); + + mem.iter_mut().for_each(|inst| { + assert!(inst.size().is_power_of_two()); + }); + outer.iter().for_each(|inst| { + assert!(inst.size().is_power_of_two()); + }); + inner.iter().for_each(|inst| { + assert!(inst.size().is_power_of_two()); + }); + witness.iter().for_each(|inst| { + assert!(inst.size().is_power_of_two()); + }); + + let degree = mem[0].degree(); + assert!(mem.iter().all(|inst| inst.degree() == degree)); + assert!(outer.iter().all(|inst| inst.degree() == degree)); + assert!(inner.iter().all(|inst| inst.degree() == degree)); + assert!(witness.iter().all(|inst| inst.degree() == degree)); + + // Collect all claims from the instances. If the instances is defined over `m` variables, + // which is less that the total number of rounds `n`, + // the individual claims σ are scaled by 2^{n-m}. + let claims = zip_with!( + iter, + (mem, outer, inner, witness), + |mem, outer, inner, witness| { + Self::scaled_claims(mem, num_rounds) + .into_iter() + .chain(Self::scaled_claims(outer, num_rounds)) + .chain(Self::scaled_claims(inner, num_rounds)) + .chain(Self::scaled_claims(witness, num_rounds)) + } + ) + .flatten() + .collect::>(); + + // Sample a challenge for the random linear combination of all scaled claims + let s = transcript.squeeze(b"r")?; + let coeffs = powers::(&s, claims.len()); + + // At the start of each round, the running claim is equal to the random linear combination + // of the Sumcheck claims, evaluated over the bound polynomials. + // Initially, it is equal to the random linear combination of the scaled input claims. + let mut running_claim = zip_with!(iter, (claims, coeffs), |c_1, c_2| *c_1 * c_2).sum(); + + // Keep track of the verifier challenges r, and the univariate polynomials sent by the prover + // in each round + let mut r: Vec = Vec::new(); + let mut cubic_polys: Vec> = Vec::new(); + + for i in 0..num_rounds { + // At the start of round i, there input polynomials are defined over at most n-i variables. + let remaining_variables = num_rounds - i; + + // For each claim j, compute the evaluations of its univariate polynomial S_j(X_i) + // at X = 0, 2, 3. The polynomial is such that S_{j-1}(r_{j-1}) = S_j(0) + S_j(1). + // If the number of variable m of the claim is m < n-i, then the polynomial is + // constants and equal to the initial claim σ_j scaled by 2^{n-m-i-1}. + let evals = zip_with!( + par_iter, + (mem, outer, inner, witness), + |mem, outer, inner, witness| { + let ((evals_mem, evals_outer), (evals_inner, evals_witness)) = rayon::join( + || { + rayon::join( + || Self::get_evals(mem, remaining_variables), + || Self::get_evals(outer, remaining_variables), + ) + }, + || { + rayon::join( + || Self::get_evals(inner, remaining_variables), + || Self::get_evals(witness, remaining_variables), + ) + }, + ); + evals_mem + .into_par_iter() + .chain(evals_outer.into_par_iter()) + .chain(evals_inner.into_par_iter()) + .chain(evals_witness.into_par_iter()) + } + ) + .flatten() + .collect::>(); + + assert_eq!(evals.len(), claims.len()); + + // Random linear combination of the univariate evaluations at X_i = 0, 2, 3 + let evals_combined_0 = (0..evals.len()).map(|i| evals[i][0] * coeffs[i]).sum(); + let evals_combined_2 = (0..evals.len()).map(|i| evals[i][1] * coeffs[i]).sum(); + let evals_combined_3 = (0..evals.len()).map(|i| evals[i][2] * coeffs[i]).sum(); + + let evals = vec![ + evals_combined_0, + running_claim - evals_combined_0, + evals_combined_2, + evals_combined_3, + ]; + // Coefficient representation of S(X_i) + let poly = UniPoly::from_evals(&evals); + + // append the prover's message to the transcript + transcript.absorb(b"p", &poly); + + // derive the verifier's challenge for the next round + let r_i = transcript.squeeze(b"c")?; + r.push(r_i); + + // Bind the variable X_i of polynomials across all claims to r_i. + // If the claim is defined over m variables and m < n-i, then + // binding has no effect on the polynomial. + zip_with_for_each!( + par_iter_mut, + (mem, outer, inner, witness), + |mem, outer, inner, witness| { + rayon::join( + || { + rayon::join( + || Self::bind(mem, remaining_variables, &r_i), + || Self::bind(outer, remaining_variables, &r_i), + ) + }, + || { + rayon::join( + || Self::bind(inner, remaining_variables, &r_i), + || Self::bind(witness, remaining_variables, &r_i), + ) + }, + ); + } + ); + + running_claim = poly.evaluate(&r_i); + cubic_polys.push(poly.compress()); + } + + // Collect evaluations at (r_{n-m}, ..., r_{n-1}) of polynomials over all claims, + // where m is the initial number of variables the individual claims are defined over. + let claims_outer = outer.into_iter().map(|inst| inst.final_claims()).collect(); + let claims_inner = inner.into_iter().map(|inst| inst.final_claims()).collect(); + let claims_mem = mem.into_iter().map(|inst| inst.final_claims()).collect(); + let claims_witness = witness + .into_iter() + .map(|inst| inst.final_claims()) + .collect(); + + Ok(( + SumcheckProof::new(cubic_polys), + r, + claims_outer, + claims_inner, + claims_mem, + claims_witness, + )) + } + + /// In round i, computes the evaluations at X_i = 0, 2, 3 of the univariate polynomials S(X_i) + /// for each claim in the instance. + /// Let `n` be the total number of Sumcheck rounds, and assume the instance is defined over `m` variables. + /// We define `remaining_variables` as n-i. + /// If m < n-i, then the polynomials in the instance are not defined over X_i, so the univariate + /// polynomial is constant and equal to 2^{n-m-i-1}*σ, where σ is the initial claim. + fn get_evals>(inst: &T, remaining_variables: usize) -> Vec> { + let num_instance_variables = inst.size().log_2(); // m + if num_instance_variables < remaining_variables { + let deg = inst.degree(); + + // The evaluations at X_i = 0, 2, 3 are all equal to the scaled claim + Self::scaled_claims(inst, remaining_variables - 1) + .into_iter() + .map(|scaled_claim| vec![scaled_claim; deg]) + .collect() + } else { + inst.evaluation_points() + } + } + + /// In round i after receiving challenge r_i, we partially evaluate all polynomials in the instance + /// at X_i = r_i. If the instance is defined over m variables m which is less than n-i, then + /// the polynomials do not depend on X_i, so binding them to r_i has no effect. + fn bind>(inst: &mut T, remaining_variables: usize, r: &E::Scalar) { + let num_instance_variables = inst.size().log_2(); // m + if remaining_variables <= num_instance_variables { + inst.bound(r) + } + } + + /// Given an instance defined over m variables, the sum over n = `remaining_variables` is equal + /// to the initial claim scaled by 2^{n-m}, when m ≤ n. + fn scaled_claims>(inst: &T, remaining_variables: usize) -> Vec { + let num_instance_variables = inst.size().log_2(); // m + let num_repetitions = 1 << (remaining_variables - num_instance_variables); + let scaling = E::Scalar::from(num_repetitions as u64); + inst + .initial_claims() + .iter() + .map(|claim| scaling * claim) + .collect() + } +} diff --git a/src/spartan/direct.rs b/src/spartan/direct.rs index 046a5c4dd..6a7fb5094 100644 --- a/src/spartan/direct.rs +++ b/src/spartan/direct.rs @@ -108,8 +108,7 @@ impl, C: StepCircuit> DirectSN let mut cs: ShapeCS = ShapeCS::new(); let _ = circuit.synthesize(&mut cs); - - let (shape, ck) = cs.r1cs_shape(&*S::ck_floor()); + let (shape, ck) = cs.r1cs_shape_and_key(&*S::ck_floor()); let (pk, vk) = S::setup(&ck, &shape)?; diff --git a/src/spartan/macros.rs b/src/spartan/macros.rs new file mode 100644 index 000000000..9e757a2ed --- /dev/null +++ b/src/spartan/macros.rs @@ -0,0 +1,104 @@ +/// Macros to give syntactic sugar for zipWith pattern and variatns. +/// +/// ```ignore +/// use crate::spartan::zip_with; +/// use itertools::Itertools as _; // we use zip_eq to zip! +/// let v = vec![0, 1, 2]; +/// let w = vec![2, 3, 4]; +/// let y = vec![4, 5, 6]; +/// +/// // Using the `zip_with!` macro to zip three iterators together and apply a closure +/// // that sums the elements of each iterator. +/// let res = zip_with!((v.iter(), w.iter(), y.iter()), |a, b, c| a + b + c) +/// .collect::>(); +/// +/// println!("{:?}", res); // Output: [6, 9, 12] +/// ``` + +#[macro_export] +macro_rules! zip_with { + // no iterator projection specified: the macro assumes the arguments *are* iterators + // ```ignore + // zip_with!((iter1, iter2, iter3), |a, b, c| a + b + c) -> + // iter1.zip_eq(iter2.zip_eq(iter3)).map(|(a, (b, c))| a + b + c) + // ``` + // + // iterator projection specified: use it on each argument + // ```ignore + // zip_with!(par_iter, (vec1, vec2, vec3), |a, b, c| a + b + c) -> + // vec1.par_iter().zip_eq(vec2.par_iter().zip_eq(vec3.par_iter())).map(|(a, (b, c))| a + b + c) + // ```` + ($($f:ident,)? ($e:expr $(, $rest:expr)*), $($move:ident)? |$($i:ident),+ $(,)?| $($work:tt)*) => {{ + $crate::zip_with!($($f,)? ($e $(, $rest)*), map, $($move)? |$($i),+| $($work)*) + }}; + // no iterator projection specified: the macro assumes the arguments *are* iterators + // optional zipping function specified as well: use it instead of map + // ```ignore + // zip_with!((iter1, iter2, iter3), for_each, |a, b, c| a + b + c) -> + // iter1.zip_eq(iter2.zip_eq(iter3)).for_each(|(a, (b, c))| a + b + c) + // ``` + // + // + // iterator projection specified: use it on each argument + // optional zipping function specified as well: use it instead of map + // ```ignore + // zip_with!(par_iter, (vec1, vec2, vec3), for_each, |a, b, c| a + b + c) -> + // vec1.part_iter().zip_eq(vec2.par_iter().zip_eq(vec3.par_iter())).for_each(|(a, (b, c))| a + b + c) + // ``` + ($($f:ident,)? ($e:expr $(, $rest:expr)*), $worker:ident, $($move:ident,)? |$($i:ident),+ $(,)?| $($work:tt)*) => {{ + $crate::zip_all!($($f,)? ($e $(, $rest)*)) + .$worker($($move)? |$crate::nested_idents!($($i),+)| { + $($work)* + }) + }}; +} + +/// Like `zip_with` but use `for_each` instead of `map`. +#[macro_export] +macro_rules! zip_with_for_each { + // no iterator projection specified: the macro assumes the arguments *are* iterators + // ```ignore + // zip_with_for_each!((iter1, iter2, iter3), |a, b, c| a + b + c) -> + // iter1.zip_eq(iter2.zip_eq(iter3)).for_each(|(a, (b, c))| a + b + c) + // ``` + // + // iterator projection specified: use it on each argument + // ```ignore + // zip_with_for_each!(par_iter, (vec1, vec2, vec3), |a, b, c| a + b + c) -> + // vec1.par_iter().zip_eq(vec2.par_iter().zip_eq(vec3.par_iter())).for_each(|(a, (b, c))| a + b + c) + // ```` + ($($f:ident,)? ($e:expr $(, $rest:expr)*), $($move:ident)? |$($i:ident),+ $(,)?| $($work:tt)*) => {{ + $crate::zip_with!($($f,)? ($e $(, $rest)*), for_each, $($move)? |$($i),+| $($work)*) + }}; +} + +// Foldright-like nesting for idents (a, b, c) -> (a, (b, c)) +#[doc(hidden)] +#[allow(unused_macro_rules)] +#[macro_export] +macro_rules! nested_idents { + ($a:ident, $b:ident) => { + ($a, $b) + }; + ($first:ident, $($rest:ident),+) => { + ($first, $crate::nested_idents!($($rest),+)) + }; +} + +// Fold-right like zipping, with an optional function `f` to apply to each argument +#[doc(hidden)] +#[macro_export] +macro_rules! zip_all { + (($e:expr,)) => { + $e + }; + ($f:ident, ($e:expr,)) => { + $e.$f() + }; + ($f:ident, ($first:expr, $second:expr $(, $rest:expr)*)) => { + ($first.$f().zip($crate::zip_all!($f, ($second, $( $rest),*)))) + }; + (($first:expr, $second:expr $(, $rest:expr)*)) => { + ($first.zip($crate::zip_all!(($second, $( $rest),*)))) + }; +} diff --git a/src/spartan/mod.rs b/src/spartan/mod.rs index 3e7c514e8..382da8466 100644 --- a/src/spartan/mod.rs +++ b/src/spartan/mod.rs @@ -5,21 +5,31 @@ //! We also provide direct.rs that allows proving a step circuit directly with either of the two SNARKs. //! //! In polynomial.rs we also provide foundational types and functions for manipulating multilinear polynomials. + +pub mod batched; +pub mod batched_ppsnark; pub mod direct; +#[macro_use] +mod macros; pub(crate) mod math; pub mod polys; pub mod ppsnark; pub mod snark; mod sumcheck; -use crate::{traits::Engine, Commitment}; +use crate::{ + r1cs::{R1CSShape, SparseMatrix}, + traits::Engine, + Commitment, +}; use ff::Field; use polys::multilinear::SparsePolynomial; use rayon::{iter::IntoParallelRefIterator, prelude::*}; +// Creates a vector of the first `n` powers of `s`. fn powers(s: &E::Scalar, n: usize) -> Vec { assert!(n >= 1); - let mut powers = Vec::new(); + let mut powers = Vec::with_capacity(n); powers.push(E::Scalar::ONE); for i in 1..n { powers.push(powers[i - 1] * s); @@ -33,30 +43,56 @@ pub struct PolyEvalWitness { } impl PolyEvalWitness { - fn pad(mut W: Vec>) -> Vec> { - // determine the maximum size - if let Some(n) = W.iter().map(|w| w.p.len()).max() { - W.iter_mut().for_each(|w| { - w.p.resize(n, E::Scalar::ZERO); - }); - W - } else { - Vec::new() - } - } + /// Given [Pᵢ] and s, compute P = ∑ᵢ sⁱ⋅Pᵢ + /// + /// # Details + /// + /// We allow the input polynomials to have different sizes, and interpret smaller ones as + /// being padded with 0 to the maximum size of all polynomials. + fn batch_diff_size(W: Vec>, s: E::Scalar) -> PolyEvalWitness { + let powers = powers::(&s, W.len()); + + let size_max = W.iter().map(|w| w.p.len()).max().unwrap(); + // Scale the input polynomials by the power of s + let p = W + .into_par_iter() + .zip_eq(powers.par_iter()) + .map(|(mut w, s)| { + if *s != E::Scalar::ONE { + w.p.par_iter_mut().for_each(|e| *e *= s); + } + w.p + }) + .reduce( + || vec![E::Scalar::ZERO; size_max], + |left, right| { + // Sum into the largest polynomial + let (mut big, small) = if left.len() > right.len() { + (left, right) + } else { + (right, left) + }; + + #[allow(clippy::disallowed_methods)] + big + .par_iter_mut() + .zip(small.par_iter()) + .for_each(|(b, s)| *b += s); + + big + }, + ); - fn weighted_sum(W: &[PolyEvalWitness], s: &[E::Scalar]) -> PolyEvalWitness { - assert_eq!(W.len(), s.len()); - let mut p = vec![E::Scalar::ZERO; W[0].p.len()]; - for i in 0..W.len() { - for j in 0..W[i].p.len() { - p[j] += W[i].p[j] * s[i] - } - } PolyEvalWitness { p } } - // This method panics unless all vectors in p_vec are of the same length + /// Given a set of polynomials \[Pᵢ\] and a scalar `s`, this method computes the weighted sum + /// of the polynomials, where each polynomial Pᵢ is scaled by sⁱ. The method handles polynomials + /// of different sizes by padding smaller ones with zeroes up to the size of the largest polynomial. + /// + /// # Panics + /// + /// This method panics if the polynomials in `p_vec` are not all of the same length. fn batch(p_vec: &[&Vec], s: &E::Scalar) -> PolyEvalWitness { p_vec .iter() @@ -64,20 +100,17 @@ impl PolyEvalWitness { let powers_of_s = powers::(s, p_vec.len()); - let p = p_vec - .par_iter() - .zip(powers_of_s.par_iter()) - .map(|(v, &weight)| { - // compute the weighted sum for each vector - v.iter().map(|&x| x * weight).collect::>() - }) - .reduce( - || vec![E::Scalar::ZERO; p_vec[0].len()], - |acc, v| { - // perform vector addition to combine the weighted vectors - acc.into_iter().zip(v).map(|(x, y)| x + y).collect() - }, - ); + let p = zip_with!(par_iter, (p_vec, powers_of_s), |v, weight| { + // compute the weighted sum for each vector + v.iter().map(|&x| x * *weight).collect::>() + }) + .reduce( + || vec![E::Scalar::ZERO; p_vec[0].len()], + |acc, v| { + // perform vector addition to combine the weighted vectors + acc.into_iter().zip(v).map(|(x, y)| x + y).collect() + }, + ); PolyEvalWitness { p } } @@ -91,18 +124,47 @@ pub struct PolyEvalInstance { } impl PolyEvalInstance { - fn pad(U: Vec>) -> Vec> { - // determine the maximum size - if let Some(ell) = U.iter().map(|u| u.x.len()).max() { - U.into_iter() - .map(|mut u| { - let mut x = vec![E::Scalar::ZERO; ell - u.x.len()]; - x.append(&mut u.x); - PolyEvalInstance { x, ..u } - }) - .collect() - } else { - Vec::new() + fn batch_diff_size( + c_vec: &[Commitment], + e_vec: &[E::Scalar], + num_vars: &[usize], + x: Vec, + s: E::Scalar, + ) -> PolyEvalInstance { + let num_instances = num_vars.len(); + assert_eq!(c_vec.len(), num_instances); + assert_eq!(e_vec.len(), num_instances); + + let num_vars_max = x.len(); + let powers: Vec = powers::(&s, num_instances); + // Rescale evaluations by the first Lagrange polynomial, + // so that we can check its evaluation against x + let evals_scaled = zip_with!(iter, (e_vec, num_vars), |eval, num_rounds| { + // x_lo = [ x[0] , ..., x[n-nᵢ-1] ] + // x_hi = [ x[n-nᵢ], ..., x[n] ] + let (r_lo, _r_hi) = x.split_at(num_vars_max - num_rounds); + // Compute L₀(x_lo) + let lagrange_eval = r_lo + .iter() + .map(|r| E::Scalar::ONE - r) + .product::(); + + // vᵢ = L₀(x_lo)⋅Pᵢ(x_hi) + lagrange_eval * eval + }) + .collect::>(); + + // C = ∑ᵢ γⁱ⋅Cᵢ + let comm_joint = zip_with!(iter, (c_vec, powers), |c, g_i| *c * *g_i) + .fold(Commitment::::default(), |acc, item| acc + item); + + // v = ∑ᵢ γⁱ⋅vᵢ + let eval_joint = zip_with!((evals_scaled.into_iter(), powers.iter()), |e, g_i| e * g_i).sum(); + + PolyEvalInstance { + c: comm_joint, + x, + e: eval_joint, } } @@ -112,16 +174,14 @@ impl PolyEvalInstance { e_vec: &[E::Scalar], s: &E::Scalar, ) -> PolyEvalInstance { - let powers_of_s = powers::(s, c_vec.len()); - let e = e_vec - .par_iter() - .zip(powers_of_s.par_iter()) - .map(|(e, p)| *e * p) - .sum(); - let c = c_vec - .par_iter() - .zip(powers_of_s.par_iter()) - .map(|(c, p)| *c * *p) + let num_instances = c_vec.len(); + assert_eq!(e_vec.len(), num_instances); + + let powers_of_s = powers::(s, num_instances); + // Weighted sum of evaluations + let e = zip_with!(par_iter, (e_vec, powers_of_s), |e, p| *e * p).sum(); + // Weighted sum of commitments + let c = zip_with!(par_iter, (c_vec, powers_of_s), |c, p| *c * *p) .reduce(Commitment::::default, |acc, item| acc + item); PolyEvalInstance { @@ -131,3 +191,43 @@ impl PolyEvalInstance { } } } + +/// Bounds "row" variables of (A, B, C) matrices viewed as 2d multilinear polynomials +pub fn compute_eval_table_sparse( + S: &R1CSShape, + rx: &[E::Scalar], +) -> (Vec, Vec, Vec) { + assert_eq!(rx.len(), S.num_cons); + + let inner = |M: &SparseMatrix, M_evals: &mut Vec| { + for (row_idx, ptrs) in M.indptr.windows(2).enumerate() { + for (val, col_idx) in M.get_row_unchecked(ptrs.try_into().unwrap()) { + M_evals[*col_idx] += rx[row_idx] * val; + } + } + }; + + let (A_evals, (B_evals, C_evals)) = rayon::join( + || { + let mut A_evals: Vec = vec![E::Scalar::ZERO; 2 * S.num_vars]; + inner(&S.A, &mut A_evals); + A_evals + }, + || { + rayon::join( + || { + let mut B_evals: Vec = vec![E::Scalar::ZERO; 2 * S.num_vars]; + inner(&S.B, &mut B_evals); + B_evals + }, + || { + let mut C_evals: Vec = vec![E::Scalar::ZERO; 2 * S.num_vars]; + inner(&S.C, &mut C_evals); + C_evals + }, + ) + }, + ); + + (A_evals, B_evals, C_evals) +} diff --git a/src/spartan/polys/eq.rs b/src/spartan/polys/eq.rs index 22ef1a137..89a81dcfc 100644 --- a/src/spartan/polys/eq.rs +++ b/src/spartan/polys/eq.rs @@ -14,8 +14,9 @@ use rayon::prelude::{IndexedParallelIterator, IntoParallelRefMutIterator, Parall /// This polynomial evaluates to 1 if every component $x_i$ equals its corresponding $e_i$, and 0 otherwise. /// /// For instance, for e = 6 (with a binary representation of 0b110), the vector r would be [1, 1, 0]. +#[derive(Debug)] pub struct EqPolynomial { - r: Vec, + pub(crate) r: Vec, } impl EqPolynomial { @@ -35,7 +36,7 @@ impl EqPolynomial { pub fn evaluate(&self, rx: &[Scalar]) -> Scalar { assert_eq!(self.r.len(), rx.len()); (0..rx.len()) - .map(|i| rx[i] * self.r[i] + (Scalar::ONE - rx[i]) * (Scalar::ONE - self.r[i])) + .map(|i| self.r[i] * rx[i] + (Scalar::ONE - self.r[i]) * (Scalar::ONE - rx[i])) .fold(Scalar::ONE, |acc, item| acc * item) } @@ -43,18 +44,26 @@ impl EqPolynomial { /// /// Returns a vector of Scalars, each corresponding to the polynomial evaluation at a specific point. pub fn evals(&self) -> Vec { - let ell = self.r.len(); + Self::evals_from_points(&self.r) + } + + /// Evaluates the `EqPolynomial` from the `2^|r|` points in its domain, without creating an intermediate polynomial + /// representation. + /// + /// Returns a vector of Scalars, each corresponding to the polynomial evaluation at a specific point. + pub fn evals_from_points(r: &[Scalar]) -> Vec { + let ell = r.len(); let mut evals: Vec = vec![Scalar::ZERO; (2_usize).pow(ell as u32)]; let mut size = 1; evals[0] = Scalar::ONE; - for r in self.r.iter().rev() { + for r in r.iter().rev() { let (evals_left, evals_right) = evals.split_at_mut(size); let (evals_right, _) = evals_right.split_at_mut(size); evals_left .par_iter_mut() - .zip(evals_right.par_iter_mut()) + .zip_eq(evals_right.par_iter_mut()) .for_each(|(x, y)| { *y = *x * r; *x -= &*y; @@ -67,6 +76,13 @@ impl EqPolynomial { } } +impl FromIterator for EqPolynomial { + fn from_iter>(iter: I) -> Self { + let r: Vec<_> = iter.into_iter().collect(); + EqPolynomial { r } + } +} + #[cfg(test)] mod tests { use crate::provider; diff --git a/src/spartan/polys/masked_eq.rs b/src/spartan/polys/masked_eq.rs new file mode 100644 index 000000000..3d6ac7507 --- /dev/null +++ b/src/spartan/polys/masked_eq.rs @@ -0,0 +1,150 @@ +//! `MaskedEqPolynomial`: Represents the `eq` polynomial over n variables, where the first 2^m entries are 0. + +use crate::spartan::polys::eq::EqPolynomial; +use ff::PrimeField; +use itertools::zip_eq; + +/// Represents the multilinear extension polynomial (MLE) of the equality polynomial $eqₘ(x,r)$ +/// over n variables, where the first 2^m evaluations are 0. +/// +/// The polynomial is defined by the formula: +/// eqₘ(x,r) = eq(x,r) - ( ∏_{0 ≤ i < n-m} (1−rᵢ)(1−xᵢ) )⋅( ∏_{n-m ≤ i < n} (1−rᵢ)(1−xᵢ) + rᵢ⋅xᵢ ) +#[derive(Debug)] +pub struct MaskedEqPolynomial<'a, Scalar: PrimeField> { + eq: &'a EqPolynomial, + num_masked_vars: usize, +} + +impl<'a, Scalar: PrimeField> MaskedEqPolynomial<'a, Scalar> { + /// Creates a new `MaskedEqPolynomial` from a vector of Scalars `r` of size n, with the number of + /// masked variables m = `num_masked_vars`. + pub const fn new(eq: &'a EqPolynomial, num_masked_vars: usize) -> Self { + MaskedEqPolynomial { + eq, + num_masked_vars, + } + } + + /// Evaluates the `MaskedEqPolynomial` at a given point `rx`. + /// + /// This function computes the value of the polynomial at the point specified by `rx`. + /// It expects `rx` to have the same length as the internal vector `r`. + /// + /// Panics if `rx` and `r` have different lengths. + pub fn evaluate(&self, rx: &[Scalar]) -> Scalar { + let r = &self.eq.r; + assert_eq!(r.len(), rx.len()); + let split_idx = r.len() - self.num_masked_vars; + + let (r_lo, r_hi) = r.split_at(split_idx); + let (rx_lo, rx_hi) = rx.split_at(split_idx); + let eq_lo = zip_eq(r_lo, rx_lo) + .map(|(r, rx)| *r * rx + (Scalar::ONE - r) * (Scalar::ONE - rx)) + .product::(); + let eq_hi = zip_eq(r_hi, rx_hi) + .map(|(r, rx)| *r * rx + (Scalar::ONE - r) * (Scalar::ONE - rx)) + .product::(); + let mask_lo = zip_eq(r_lo, rx_lo) + .map(|(r, rx)| (Scalar::ONE - r) * (Scalar::ONE - rx)) + .product::(); + + (eq_lo - mask_lo) * eq_hi + } + + /// Evaluates the `MaskedEqPolynomial` at all the `2^|r|` points in its domain. + /// + /// Returns a vector of Scalars, each corresponding to the polynomial evaluation at a specific point. + pub fn evals(&self) -> Vec { + Self::evals_from_points(&self.eq.r, self.num_masked_vars) + } + + /// Evaluates the `MaskedEqPolynomial` from the `2^|r|` points in its domain, without creating an intermediate polynomial + /// representation. + /// + /// Returns a vector of Scalars, each corresponding to the polynomial evaluation at a specific point. + pub fn evals_from_points(r: &[Scalar], num_masked_vars: usize) -> Vec { + let mut evals = EqPolynomial::evals_from_points(r); + + // replace the first 2^m evaluations with 0 + let num_masked_evals = 1 << num_masked_vars; + evals[..num_masked_evals] + .iter_mut() + .for_each(|e| *e = Scalar::ZERO); + + evals + } +} + +#[cfg(test)] +mod tests { + use crate::provider; + + use super::*; + use crate::spartan::polys::eq::EqPolynomial; + use pasta_curves::Fp; + use rand_chacha::ChaCha20Rng; + use rand_core::{CryptoRng, RngCore, SeedableRng}; + + fn test_masked_eq_polynomial_with( + num_vars: usize, + num_masked_vars: usize, + mut rng: &mut R, + ) { + let num_masked_evals = 1 << num_masked_vars; + + // random point + let r = std::iter::from_fn(|| Some(F::random(&mut rng))) + .take(num_vars) + .collect::>(); + // evaluation point + let rx = std::iter::from_fn(|| Some(F::random(&mut rng))) + .take(num_vars) + .collect::>(); + + let poly_eq = EqPolynomial::new(r); + let poly_eq_evals = poly_eq.evals(); + + let masked_eq_poly = MaskedEqPolynomial::new(&poly_eq, num_masked_vars); + let masked_eq_poly_evals = masked_eq_poly.evals(); + + // ensure the first 2^m entries are 0 + assert_eq!( + masked_eq_poly_evals[..num_masked_evals], + vec![F::ZERO; num_masked_evals] + ); + // ensure the remaining evaluations match eq(r) + assert_eq!( + masked_eq_poly_evals[num_masked_evals..], + poly_eq_evals[num_masked_evals..] + ); + + // compute the evaluation at rx succinctly + let masked_eq_eval = masked_eq_poly.evaluate(&rx); + + // compute the evaluation as a MLE + let rx_evals = EqPolynomial::evals_from_points(&rx); + let expected_masked_eq_eval = zip_eq(rx_evals, masked_eq_poly_evals) + .map(|(rx, r)| rx * r) + .sum(); + + assert_eq!(masked_eq_eval, expected_masked_eq_eval); + } + + #[test] + fn test_masked_eq_polynomial() { + let mut rng = ChaCha20Rng::from_seed([0u8; 32]); + let num_vars = 5; + let num_masked_vars = 2; + test_masked_eq_polynomial_with::(num_vars, num_masked_vars, &mut rng); + test_masked_eq_polynomial_with::( + num_vars, + num_masked_vars, + &mut rng, + ); + test_masked_eq_polynomial_with::( + num_vars, + num_masked_vars, + &mut rng, + ); + } +} diff --git a/src/spartan/polys/mod.rs b/src/spartan/polys/mod.rs index d19d56f77..a1a192ef8 100644 --- a/src/spartan/polys/mod.rs +++ b/src/spartan/polys/mod.rs @@ -1,6 +1,7 @@ //! This module contains the definitions of polynomial types used in the Spartan SNARK. pub(crate) mod eq; pub(crate) mod identity; +pub(crate) mod masked_eq; pub(crate) mod multilinear; pub(crate) mod power; pub(crate) mod univariate; diff --git a/src/spartan/polys/multilinear.rs b/src/spartan/polys/multilinear.rs index 385d8a342..b54f88726 100644 --- a/src/spartan/polys/multilinear.rs +++ b/src/spartan/polys/multilinear.rs @@ -38,13 +38,12 @@ pub struct MultilinearPolynomial { impl MultilinearPolynomial { /// Creates a new `MultilinearPolynomial` from the given evaluations. /// + /// # Panics /// The number of evaluations must be a power of two. pub fn new(Z: Vec) -> Self { - assert_eq!(Z.len(), (2_usize).pow((Z.len() as f64).log2() as u32)); - MultilinearPolynomial { - num_vars: usize::try_from(Z.len().ilog2()).unwrap(), - Z, - } + let num_vars = Z.len().log_2(); + assert_eq!(Z.len(), 1 << num_vars); + MultilinearPolynomial { num_vars, Z } } /// Returns the number of variables in the multilinear polynomial @@ -57,17 +56,19 @@ impl MultilinearPolynomial { self.Z.len() } - /// Bounds the polynomial's top variable using the given scalar. + /// Binds the polynomial's top variable using the given scalar. /// /// This operation modifies the polynomial in-place. pub fn bind_poly_var_top(&mut self, r: &Scalar) { + assert!(self.num_vars > 0); + let n = self.len() / 2; let (left, right) = self.Z.split_at_mut(n); left .par_iter_mut() - .zip(right.par_iter()) + .zip_eq(right.par_iter()) .for_each(|(a, b)| { *a += *r * (*b - *a); }); @@ -83,23 +84,25 @@ impl MultilinearPolynomial { pub fn evaluate(&self, r: &[Scalar]) -> Scalar { // r must have a value for each variable assert_eq!(r.len(), self.get_num_vars()); - let chis = EqPolynomial::new(r.to_vec()).evals(); - assert_eq!(chis.len(), self.Z.len()); + let chis = EqPolynomial::evals_from_points(r); - (0..chis.len()) - .into_par_iter() - .map(|i| chis[i] * self.Z[i]) - .sum() + zip_with!( + (chis.into_par_iter(), self.Z.par_iter()), + |chi_i, Z_i| chi_i * Z_i + ) + .sum() } /// Evaluates the polynomial with the given evaluations and point. pub fn evaluate_with(Z: &[Scalar], r: &[Scalar]) -> Scalar { - EqPolynomial::new(r.to_vec()) - .evals() - .into_par_iter() - .zip(Z.into_par_iter()) - .map(|(a, b)| a * b) - .sum() + zip_with!( + ( + EqPolynomial::evals_from_points(r).into_par_iter(), + Z.par_iter() + ), + |a, b| a * b + ) + .sum() } } @@ -143,7 +146,7 @@ impl SparsePolynomial { chi_i } - // Takes O(n log n) + // Takes O(m log n) where m is the number of non-zero evaluations and n is the number of variables. pub fn evaluate(&self, r: &[Scalar]) -> Scalar { assert_eq!(self.num_vars, r.len()); @@ -167,12 +170,7 @@ impl Add for MultilinearPolynomial { return Err("The two polynomials must have the same number of variables"); } - let sum: Vec = self - .Z - .iter() - .zip(other.Z.iter()) - .map(|(a, b)| *a + *b) - .collect(); + let sum: Vec = zip_with!(into_iter, (self.Z, other.Z), |a, b| a + b).collect(); Ok(MultilinearPolynomial::new(sum)) } @@ -184,7 +182,7 @@ mod tests { use super::*; use rand_chacha::ChaCha20Rng; - use rand_core::{CryptoRng, RngCore, SeedableRng}; + use rand_core::{SeedableRng, RngCore, CryptoRng}; fn make_mlp(len: usize, value: F) -> MultilinearPolynomial { MultilinearPolynomial { @@ -264,7 +262,7 @@ mod tests { let num_evals = 4; let mut evals: Vec = Vec::with_capacity(num_evals); for _ in 0..num_evals { - evals.push(F::from_u128(8)); + evals.push(F::from(8)); } let dense_poly: MultilinearPolynomial = MultilinearPolynomial::new(evals.clone()); diff --git a/src/spartan/polys/power.rs b/src/spartan/polys/power.rs index 06721a23c..9350a229e 100644 --- a/src/spartan/polys/power.rs +++ b/src/spartan/polys/power.rs @@ -2,6 +2,7 @@ use crate::spartan::polys::eq::EqPolynomial; use ff::PrimeField; +use std::iter::successors; /// Represents the multilinear extension polynomial (MLE) of the equality polynomial $pow(x,t)$, denoted as $\tilde{pow}(x, t)$. /// @@ -10,7 +11,6 @@ use ff::PrimeField; /// \tilde{power}(x, t) = \prod_{i=1}^m(1 + (t^{2^i} - 1) * x_i) /// $$ pub struct PowPolynomial { - t_pow: Vec, eq: EqPolynomial, } @@ -18,18 +18,28 @@ impl PowPolynomial { /// Creates a new `PowPolynomial` from a Scalars `t`. pub fn new(t: &Scalar, ell: usize) -> Self { // t_pow = [t^{2^0}, t^{2^1}, ..., t^{2^{ell-1}}] - let mut t_pow = vec![Scalar::ONE; ell]; - t_pow[0] = *t; - for i in 1..ell { - t_pow[i] = t_pow[i - 1].square(); - } + let t_pow = Self::squares(t, ell); PowPolynomial { - t_pow: t_pow.clone(), eq: EqPolynomial::new(t_pow), } } + /// Create powers the following powers of `t`: + /// [t^{2^0}, t^{2^1}, ..., t^{2^{ell-1}}] + pub(in crate::spartan) fn squares(t: &Scalar, ell: usize) -> Vec { + successors(Some(*t), |p: &Scalar| Some(p.square())) + .take(ell) + .collect::>() + } + + /// Creates the evals corresponding to a `PowPolynomial` from an already-existing vector of powers. + /// `t_pow.len() > ell` must be true. + pub(crate) fn evals_with_powers(powers: &[Scalar], ell: usize) -> Vec { + let t_pow = powers[..ell].to_vec(); + EqPolynomial::evals_from_points(&t_pow) + } + /// Evaluates the `PowPolynomial` at a given point `rx`. /// /// This function computes the value of the polynomial at the point specified by `rx`. @@ -40,8 +50,8 @@ impl PowPolynomial { self.eq.evaluate(rx) } - pub fn coordinates(&self) -> Vec { - self.t_pow.clone() + pub fn coordinates(self) -> Vec { + self.eq.r } /// Evaluates the `PowPolynomial` at all the `2^|t_pow|` points in its domain. diff --git a/src/spartan/polys/univariate.rs b/src/spartan/polys/univariate.rs index bfe983e5b..37a10ba79 100644 --- a/src/spartan/polys/univariate.rs +++ b/src/spartan/polys/univariate.rs @@ -3,15 +3,17 @@ //! - `CompressedUniPoly`: a univariate dense polynomial, compressed (omitted linear term), in coefficient form (little endian), use ff::PrimeField; use rayon::prelude::{IntoParallelIterator, ParallelIterator}; +use ref_cast::RefCast; use serde::{Deserialize, Serialize}; use crate::traits::{Group, TranscriptReprTrait}; // ax^2 + bx + c stored as vec![c, b, a] // ax^3 + bx^2 + cx + d stored as vec![d, c, b, a] -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq, Eq, RefCast)] +#[repr(transparent)] pub struct UniPoly { - coeffs: Vec, + pub coeffs: Vec, } // ax^2 + bx + c stored as vec![c, a] @@ -22,6 +24,7 @@ pub struct CompressedUniPoly { } impl UniPoly { + pub fn from_evals(evals: &[Scalar]) -> Self { // we only support degree-2 or degree-3 univariate polynomials assert!(evals.len() == 3 || evals.len() == 4); diff --git a/src/spartan/ppsnark.rs b/src/spartan/ppsnark.rs index bebe6ff54..a43159299 100644 --- a/src/spartan/ppsnark.rs +++ b/src/spartan/ppsnark.rs @@ -27,7 +27,7 @@ use crate::{ snark::{DigestHelperTrait, RelaxedR1CSSNARKTrait}, Engine, TranscriptEngineTrait, TranscriptReprTrait, }, - Commitment, CommitmentKey, CompressedCommitment, + zip_with, Commitment, CommitmentKey, CompressedCommitment, }; use core::cmp::max; use ff::Field; @@ -35,6 +35,8 @@ use once_cell::sync::OnceCell; use rayon::prelude::*; use serde::{Deserialize, Serialize}; +use super::polys::masked_eq::MaskedEqPolynomial; + fn padded(v: &[E::Scalar], n: usize, e: &E::Scalar) -> Vec { let mut v_padded = vec![*e; n]; for (i, v_i) in v.iter().enumerate() { @@ -47,36 +49,36 @@ fn padded(v: &[E::Scalar], n: usize, e: &E::Scalar) -> Vec #[derive(Clone, Serialize, Deserialize)] #[serde(bound = "")] pub struct R1CSShapeSparkRepr { - N: usize, // size of the vectors + pub(in crate::spartan) N: usize, // size of the vectors // dense representation - row: Vec, - col: Vec, - val_A: Vec, - val_B: Vec, - val_C: Vec, + pub(in crate::spartan) row: Vec, + pub(in crate::spartan) col: Vec, + pub(in crate::spartan) val_A: Vec, + pub(in crate::spartan) val_B: Vec, + pub(in crate::spartan) val_C: Vec, // timestamp polynomials - ts_row: Vec, - ts_col: Vec, + pub(in crate::spartan) ts_row: Vec, + pub(in crate::spartan) ts_col: Vec, } /// A type that holds a commitment to a sparse polynomial #[derive(Clone, Serialize, Deserialize)] #[serde(bound = "")] pub struct R1CSShapeSparkCommitment { - N: usize, // size of each vector + pub(in crate::spartan) N: usize, // size of each vector // commitments to the dense representation - comm_row: Commitment, - comm_col: Commitment, - comm_val_A: Commitment, - comm_val_B: Commitment, - comm_val_C: Commitment, + pub(in crate::spartan) comm_row: Commitment, + pub(in crate::spartan) comm_col: Commitment, + pub(in crate::spartan) comm_val_A: Commitment, + pub(in crate::spartan) comm_val_B: Commitment, + pub(in crate::spartan) comm_val_C: Commitment, // commitments to the timestamp polynomials - comm_ts_row: Commitment, - comm_ts_col: Commitment, + pub(in crate::spartan) comm_ts_row: Commitment, + pub(in crate::spartan) comm_ts_col: Commitment, } impl TranscriptReprTrait for R1CSShapeSparkCommitment { @@ -172,7 +174,7 @@ impl R1CSShapeSparkRepr { } } - fn commit(&self, ck: &CommitmentKey) -> R1CSShapeSparkCommitment { + pub(in crate::spartan) fn commit(&self, ck: &CommitmentKey) -> R1CSShapeSparkCommitment { let comm_vec: Vec> = [ &self.row, &self.col, @@ -199,7 +201,7 @@ impl R1CSShapeSparkRepr { } // computes evaluation oracles - fn evaluation_oracles( + pub(in crate::spartan) fn evaluation_oracles( &self, S: &R1CSShape, r_x: &E::Scalar, @@ -256,7 +258,85 @@ pub trait SumcheckEngine: Send + Sync { fn final_claims(&self) -> Vec>; } -struct MemorySumcheckInstance { +/// The [WitnessBoundSumcheck] ensures that the witness polynomial W defined over n = log(N) variables, +/// is zero outside of the first `num_vars = 2^m` entries. +/// +/// # Details +/// +/// The `W` polynomial is padded with zeros to size N = 2^n. +/// The `masked_eq` polynomials is defined as with regards to a random challenge `tau` as +/// the eq(tau) polynomial, where the first 2^m evaluations to 0. +/// +/// The instance is given by +/// `0 = ∑_{0≤i<2^n} masked_eq[i] * W[i]`. +/// It is equivalent to the expression +/// `0 = ∑_{2^m≤i<2^n} eq[i] * W[i]` +/// Since `eq` is random, the instance is only satisfied if `W[2^{m}..] = 0`. +pub(in crate::spartan) struct WitnessBoundSumcheck { + poly_W: MultilinearPolynomial, + poly_masked_eq: MultilinearPolynomial, +} + +impl WitnessBoundSumcheck { + pub fn new(tau: E::Scalar, poly_W_padded: Vec, num_vars: usize) -> Self { + let num_vars_log = num_vars.log_2(); + // When num_vars = num_rounds, we shouldn't have to prove anything + // but we still want this instance to compute the evaluation of W + let num_rounds = poly_W_padded.len().log_2(); + assert!(num_vars_log < num_rounds); + + let tau_coords = PowPolynomial::new(&tau, num_rounds).coordinates(); + let poly_masked_eq_evals = + MaskedEqPolynomial::new(&EqPolynomial::new(tau_coords), num_vars_log).evals(); + + Self { + poly_W: MultilinearPolynomial::new(poly_W_padded), + poly_masked_eq: MultilinearPolynomial::new(poly_masked_eq_evals), + } + } +} +impl SumcheckEngine for WitnessBoundSumcheck { + fn initial_claims(&self) -> Vec { + vec![E::Scalar::ZERO] + } + + fn degree(&self) -> usize { + 3 + } + + fn size(&self) -> usize { + assert_eq!(self.poly_W.len(), self.poly_masked_eq.len()); + self.poly_W.len() + } + + fn evaluation_points(&self) -> Vec> { + let comb_func = |poly_A_comp: &E::Scalar, + poly_B_comp: &E::Scalar, + _: &E::Scalar| + -> E::Scalar { *poly_A_comp * *poly_B_comp }; + + let (eval_point_0, eval_point_2, eval_point_3) = SumcheckProof::::compute_eval_points_cubic( + &self.poly_masked_eq, + &self.poly_W, + &self.poly_W, // unused + &comb_func, + ); + + vec![vec![eval_point_0, eval_point_2, eval_point_3]] + } + + fn bound(&mut self, r: &E::Scalar) { + [&mut self.poly_W, &mut self.poly_masked_eq] + .par_iter_mut() + .for_each(|poly| poly.bind_poly_var_top(r)); + } + + fn final_claims(&self) -> Vec> { + vec![vec![self.poly_W[0], self.poly_masked_eq[0]]] + } +} + +pub(in crate::spartan) struct MemorySumcheckInstance { // row w_plus_r_row: MultilinearPolynomial, t_plus_r_row: MultilinearPolynomial, @@ -279,17 +359,65 @@ struct MemorySumcheckInstance { } impl MemorySumcheckInstance { - pub fn new( + /// Computes witnesses for MemoryInstanceSumcheck + /// + /// # Description + /// We use the logUp protocol to prove that + /// ∑ TS[i]/(T[i] + r) - 1/(W[i] + r) = 0 + /// where + /// T_row[i] = mem_row[i] * gamma + i + /// = eq(tau)[i] * gamma + i + /// W_row[i] = L_row[i] * gamma + addr_row[i] + /// = eq(tau)[row[i]] * gamma + addr_row[i] + /// T_col[i] = mem_col[i] * gamma + i + /// = z[i] * gamma + i + /// W_col[i] = addr_col[i] * gamma + addr_col[i] + /// = z[col[i]] * gamma + addr_col[i] + /// and + /// TS_row, TS_col are integer-valued vectors representing the number of reads + /// to each memory cell of L_row, L_col + /// + /// The function returns oracles for the polynomials TS[i]/(T[i] + r), 1/(W[i] + r), + /// as well as auxiliary polynomials T[i] + r, W[i] + r + pub fn compute_oracles( ck: &CommitmentKey, r: &E::Scalar, - T_row: &[E::Scalar], - W_row: &[E::Scalar], - ts_row: Vec, - T_col: &[E::Scalar], - W_col: &[E::Scalar], - ts_col: Vec, - transcript: &mut E::TE, - ) -> Result<(Self, [Commitment; 4], [Vec; 4]), NovaError> { + gamma: &E::Scalar, + mem_row: &[E::Scalar], + addr_row: &[E::Scalar], + L_row: &[E::Scalar], + ts_row: &[E::Scalar], + mem_col: &[E::Scalar], + addr_col: &[E::Scalar], + L_col: &[E::Scalar], + ts_col: &[E::Scalar], + ) -> Result<([Commitment; 4], [Vec; 4], [Vec; 4]), NovaError> { + // hash the tuples of (addr,val) memory contents and read responses into a single field element using `hash_func` + let hash_func_vec = |mem: &[E::Scalar], + addr: &[E::Scalar], + lookups: &[E::Scalar]| + -> (Vec, Vec) { + let hash_func = |addr: &E::Scalar, val: &E::Scalar| -> E::Scalar { *val * gamma + *addr }; + assert_eq!(addr.len(), lookups.len()); + rayon::join( + || { + (0..mem.len()) + .map(|i| hash_func(&E::Scalar::from(i as u64), &mem[i])) + .collect::>() + }, + || { + (0..addr.len()) + .map(|i| hash_func(&addr[i], &lookups[i])) + .collect::>() + }, + ) + }; + + let ((T_row, W_row), (T_col, W_col)) = rayon::join( + || hash_func_vec(mem_row, addr_row, L_row), + || hash_func_vec(mem_col, addr_col, L_col), + ); + let batch_invert = |v: &[E::Scalar]| -> Result, NovaError> { let mut products = vec![E::Scalar::ZERO; v.len()]; let mut acc = E::Scalar::ONE; @@ -340,10 +468,7 @@ impl MemorySumcheckInstance { // compute inv[i] * TS[i] in parallel Ok( - inv - .par_iter() - .zip(TS.par_iter()) - .map(|(e1, e2)| *e1 * *e2) + zip_with!((inv.into_par_iter(), TS.par_iter()), |e1, e2| e1 * *e2) .collect::>(), ) }, @@ -363,8 +488,8 @@ impl MemorySumcheckInstance { ((t_plus_r_inv_row, w_plus_r_inv_row), (t_plus_r_row, w_plus_r_row)), ((t_plus_r_inv_col, w_plus_r_inv_col), (t_plus_r_col, w_plus_r_col)), ) = rayon::join( - || helper(T_row, W_row, &ts_row, r), - || helper(T_col, W_col, &ts_col, r), + || helper(&T_row, &W_row, ts_row, r), + || helper(&T_col, &W_col, ts_col, r), ); let t_plus_r_inv_row = t_plus_r_inv_row?; @@ -390,21 +515,6 @@ impl MemorySumcheckInstance { }, ); - // absorb the commitments - transcript.absorb( - b"l", - &[ - comm_t_plus_r_inv_row, - comm_w_plus_r_inv_row, - comm_t_plus_r_inv_col, - comm_w_plus_r_inv_col, - ] - .as_slice(), - ); - - let rho = transcript.squeeze(b"r")?; - let poly_eq = MultilinearPolynomial::new(PowPolynomial::new(&rho, T_row.len().log_2()).evals()); - let comm_vec = [ comm_t_plus_r_inv_row, comm_w_plus_r_inv_row, @@ -413,32 +523,43 @@ impl MemorySumcheckInstance { ]; let poly_vec = [ - t_plus_r_inv_row.clone(), - w_plus_r_inv_row.clone(), - t_plus_r_inv_col.clone(), - w_plus_r_inv_col.clone(), + t_plus_r_inv_row, + w_plus_r_inv_row, + t_plus_r_inv_col, + w_plus_r_inv_col, ]; - let zero = vec![E::Scalar::ZERO; t_plus_r_inv_row.len()]; + let aux_poly_vec = [t_plus_r_row?, w_plus_r_row?, t_plus_r_col?, w_plus_r_col?]; - Ok(( - Self { - w_plus_r_row: MultilinearPolynomial::new(w_plus_r_row?), - t_plus_r_row: MultilinearPolynomial::new(t_plus_r_row?), - t_plus_r_inv_row: MultilinearPolynomial::new(t_plus_r_inv_row), - w_plus_r_inv_row: MultilinearPolynomial::new(w_plus_r_inv_row), - ts_row: MultilinearPolynomial::new(ts_row), - w_plus_r_col: MultilinearPolynomial::new(w_plus_r_col?), - t_plus_r_col: MultilinearPolynomial::new(t_plus_r_col?), - t_plus_r_inv_col: MultilinearPolynomial::new(t_plus_r_inv_col), - w_plus_r_inv_col: MultilinearPolynomial::new(w_plus_r_inv_col), - ts_col: MultilinearPolynomial::new(ts_col), - poly_eq, - poly_zero: MultilinearPolynomial::new(zero), - }, - comm_vec, - poly_vec, - )) + Ok((comm_vec, poly_vec, aux_poly_vec)) + } + + pub fn new( + polys_oracle: [Vec; 4], + polys_aux: [Vec; 4], + poly_eq: Vec, + ts_row: Vec, + ts_col: Vec, + ) -> Self { + let [t_plus_r_inv_row, w_plus_r_inv_row, t_plus_r_inv_col, w_plus_r_inv_col] = polys_oracle; + let [t_plus_r_row, w_plus_r_row, t_plus_r_col, w_plus_r_col] = polys_aux; + + let zero = vec![E::Scalar::ZERO; poly_eq.len()]; + + Self { + w_plus_r_row: MultilinearPolynomial::new(w_plus_r_row), + t_plus_r_row: MultilinearPolynomial::new(t_plus_r_row), + t_plus_r_inv_row: MultilinearPolynomial::new(t_plus_r_inv_row), + w_plus_r_inv_row: MultilinearPolynomial::new(w_plus_r_inv_row), + ts_row: MultilinearPolynomial::new(ts_row), + w_plus_r_col: MultilinearPolynomial::new(w_plus_r_col), + t_plus_r_col: MultilinearPolynomial::new(t_plus_r_col), + t_plus_r_inv_col: MultilinearPolynomial::new(t_plus_r_inv_col), + w_plus_r_inv_col: MultilinearPolynomial::new(w_plus_r_inv_col), + ts_col: MultilinearPolynomial::new(ts_col), + poly_eq: MultilinearPolynomial::new(poly_eq), + poly_zero: MultilinearPolynomial::new(zero), + } } } @@ -483,6 +604,7 @@ impl SumcheckEngine for MemorySumcheckInstance { -> E::Scalar { *poly_A_comp * (*poly_B_comp * *poly_C_comp - *poly_D_comp) }; // inv related evaluation points + // 0 = ∑ TS[i]/(T[i] + r) - 1/(W[i] + r) let (eval_inv_0_row, eval_inv_2_row, eval_inv_3_row) = SumcheckProof::::compute_eval_points_cubic( &self.t_plus_r_inv_row, @@ -500,6 +622,7 @@ impl SumcheckEngine for MemorySumcheckInstance { ); // row related evaluation points + // 0 = ∑ eq[i] * (inv_T[i] * (T[i] + r) - TS[i])) let (eval_T_0_row, eval_T_2_row, eval_T_3_row) = SumcheckProof::::compute_eval_points_cubic_with_additive_term( &self.poly_eq, @@ -508,6 +631,7 @@ impl SumcheckEngine for MemorySumcheckInstance { &self.ts_row, &comb_func3, ); + // 0 = ∑ eq[i] * (inv_W[i] * (T[i] + r) - 1)) let (eval_W_0_row, eval_W_2_row, eval_W_3_row) = SumcheckProof::::compute_eval_points_cubic_with_additive_term( &self.poly_eq, @@ -580,7 +704,7 @@ impl SumcheckEngine for MemorySumcheckInstance { } } -struct OuterSumcheckInstance { +pub(in crate::spartan) struct OuterSumcheckInstance { poly_tau: MultilinearPolynomial, poly_Az: MultilinearPolynomial, poly_Bz: MultilinearPolynomial, @@ -684,13 +808,27 @@ impl SumcheckEngine for OuterSumcheckInstance { } } -struct InnerSumcheckInstance { +pub(in crate::spartan) struct InnerSumcheckInstance { claim: E::Scalar, poly_L_row: MultilinearPolynomial, poly_L_col: MultilinearPolynomial, poly_val: MultilinearPolynomial, } - +impl InnerSumcheckInstance { + pub fn new( + claim: E::Scalar, + poly_L_row: MultilinearPolynomial, + poly_L_col: MultilinearPolynomial, + poly_val: MultilinearPolynomial, + ) -> Self { + Self { + claim, + poly_L_row, + poly_L_col, + poly_val, + } + } +} impl SumcheckEngine for InnerSumcheckInstance { fn initial_claims(&self) -> Vec { vec![self.claim] @@ -761,7 +899,7 @@ impl> SimpleDigestible for VerifierKey> { // commitment to oracles: the first three are for Az, Bz, Cz, @@ -810,15 +948,16 @@ pub struct RelaxedR1CSSNARK> { eval_ts_col: E::Scalar, // a PCS evaluation argument - eval_arg_W: EE::EvaluationArgument, - eval_arg_batch: EE::EvaluationArgument, + eval_arg: EE::EvaluationArgument, } -impl> RelaxedR1CSSNARK { - fn prove_helper( +impl> RelaxedR1CSSNARK +{ + fn prove_helper( mem: &mut T1, outer: &mut T2, inner: &mut T3, + witness: &mut T4, transcript: &mut E::TE, ) -> Result< ( @@ -827,6 +966,7 @@ impl> RelaxedR1CSSNARK { Vec>, Vec>, Vec>, + Vec>, ), NovaError, > @@ -834,12 +974,15 @@ impl> RelaxedR1CSSNARK { T1: SumcheckEngine, T2: SumcheckEngine, T3: SumcheckEngine, + T4: SumcheckEngine, { // sanity checks assert_eq!(mem.size(), outer.size()); assert_eq!(mem.size(), inner.size()); + assert_eq!(mem.size(), witness.size()); assert_eq!(mem.degree(), outer.degree()); assert_eq!(mem.degree(), inner.degree()); + assert_eq!(mem.degree(), witness.degree()); // these claims are already added to the transcript, so we do not need to add let claims = mem @@ -847,32 +990,30 @@ impl> RelaxedR1CSSNARK { .into_iter() .chain(outer.initial_claims()) .chain(inner.initial_claims()) + .chain(witness.initial_claims()) .collect::>(); let s = transcript.squeeze(b"r")?; let coeffs = powers::(&s, claims.len()); // compute the joint claim - let claim = claims - .iter() - .zip(coeffs.iter()) - .map(|(c_1, c_2)| *c_1 * c_2) - .sum(); + let claim = zip_with!((claims.iter(), coeffs.iter()), |c_1, c_2| *c_1 * c_2).sum(); let mut e = claim; let mut r: Vec = Vec::new(); let mut cubic_polys: Vec> = Vec::new(); let num_rounds = mem.size().log_2(); for _ in 0..num_rounds { - let (evals_mem, (evals_outer, evals_inner)) = rayon::join( - || mem.evaluation_points(), - || rayon::join(|| outer.evaluation_points(), || inner.evaluation_points()), + let ((evals_mem, evals_outer), (evals_inner, evals_witness)) = rayon::join( + || rayon::join(|| mem.evaluation_points(), || outer.evaluation_points()), + || rayon::join(|| inner.evaluation_points(), || witness.evaluation_points()), ); let evals: Vec> = evals_mem .into_iter() .chain(evals_outer.into_iter()) .chain(evals_inner.into_iter()) + .chain(evals_witness.into_iter()) .collect::>>(); assert_eq!(evals.len(), claims.len()); @@ -896,8 +1037,8 @@ impl> RelaxedR1CSSNARK { r.push(r_i); let _ = rayon::join( - || mem.bound(&r_i), - || rayon::join(|| outer.bound(&r_i), || inner.bound(&r_i)), + || rayon::join(|| mem.bound(&r_i), || outer.bound(&r_i)), + || rayon::join(|| inner.bound(&r_i), || witness.bound(&r_i)), ); e = poly.evaluate(&r_i); @@ -907,6 +1048,7 @@ impl> RelaxedR1CSSNARK { let mem_claims = mem.final_claims(); let outer_claims = outer.final_claims(); let inner_claims = inner.final_claims(); + let witness_claims = witness.final_claims(); Ok(( SumcheckProof::new(cubic_polys), @@ -914,6 +1056,7 @@ impl> RelaxedR1CSSNARK { mem_claims, outer_claims, inner_claims, + witness_claims, )) } } @@ -948,7 +1091,8 @@ impl> DigestHelperTrait for VerifierK } } -impl> RelaxedR1CSSNARKTrait for RelaxedR1CSSNARK { +impl> RelaxedR1CSSNARKTrait for RelaxedR1CSSNARK +{ type ProverKey = ProverKey; type VerifierKey = VerifierKey; @@ -1027,13 +1171,14 @@ impl> RelaxedR1CSSNARKTrait for Relax let tau_coords = PowPolynomial::new(&tau, num_rounds_sc).coordinates(); // (1) send commitments to Az, Bz, and Cz along with their evaluations at tau - let (Az, Bz, Cz, E) = { + let (Az, Bz, Cz, W, E) = { Az.resize(pk.S_repr.N, E::Scalar::ZERO); Bz.resize(pk.S_repr.N, E::Scalar::ZERO); Cz.resize(pk.S_repr.N, E::Scalar::ZERO); let E = padded::(&W.E, pk.S_repr.N, &E::Scalar::ZERO); + let W = padded::(&W.W, pk.S_repr.N, &E::Scalar::ZERO); - (Az, Bz, Cz, E) + (Az, Bz, Cz, W, E) }; let (eval_Az_at_tau, eval_Bz_at_tau, eval_Cz_at_tau) = { let evals_at_tau = [&Az, &Bz, &Cz] @@ -1094,8 +1239,8 @@ impl> RelaxedR1CSSNARKTrait for Relax .S_repr .val_A .par_iter() - .zip(pk.S_repr.val_B.par_iter()) - .zip(pk.S_repr.val_C.par_iter()) + .zip_eq(pk.S_repr.val_B.par_iter()) + .zip_eq(pk.S_repr.val_C.par_iter()) .map(|((v_a, v_b), v_c)| *v_a + c * *v_b + c * c * *v_c) .collect::>(); let inner_sc_inst = InnerSumcheckInstance { @@ -1112,51 +1257,50 @@ impl> RelaxedR1CSSNARKTrait for Relax // we now need to prove that L_row and L_col are well-formed // hash the tuples of (addr,val) memory contents and read responses into a single field element using `hash_func` - let hash_func_vec = |mem: &[E::Scalar], - addr: &[E::Scalar], - lookups: &[E::Scalar]| - -> (Vec, Vec) { - let hash_func = |addr: &E::Scalar, val: &E::Scalar| -> E::Scalar { *val * gamma + *addr }; - assert_eq!(addr.len(), lookups.len()); - rayon::join( - || { - (0..mem.len()) - .map(|i| hash_func(&E::Scalar::from(i as u64), &mem[i])) - .collect::>() - }, - || { - (0..addr.len()) - .map(|i| hash_func(&addr[i], &lookups[i])) - .collect::>() - }, - ) - }; - let ((T_row, W_row), (T_col, W_col)) = rayon::join( - || hash_func_vec(&mem_row, &pk.S_repr.row, &L_row), - || hash_func_vec(&mem_col, &pk.S_repr.col, &L_col), - ); - - MemorySumcheckInstance::new( - ck, - &r, - &T_row, - &W_row, - pk.S_repr.ts_row.clone(), - &T_col, - &W_col, - pk.S_repr.ts_col.clone(), - &mut transcript, - ) + let (comm_mem_oracles, mem_oracles, mem_aux) = + MemorySumcheckInstance::::compute_oracles( + ck, + &r, + &gamma, + &mem_row, + &pk.S_repr.row, + &L_row, + &pk.S_repr.ts_row, + &mem_col, + &pk.S_repr.col, + &L_col, + &pk.S_repr.ts_col, + )?; + // absorb the commitments + transcript.absorb(b"l", &comm_mem_oracles.as_slice()); + + let rho = transcript.squeeze(b"r")?; + let poly_eq = MultilinearPolynomial::new(PowPolynomial::new(&rho, num_rounds_sc).evals()); + + Ok::<_, NovaError>(( + MemorySumcheckInstance::new( + mem_oracles.clone(), + mem_aux, + poly_eq.Z, + pk.S_repr.ts_row.clone(), + pk.S_repr.ts_col.clone(), + ), + comm_mem_oracles, + mem_oracles, + )) }, ); let (mut mem_sc_inst, comm_mem_oracles, mem_oracles) = mem_res?; - let (sc, rand_sc, claims_mem, claims_outer, claims_inner) = Self::prove_helper( + let mut witness_sc_inst = WitnessBoundSumcheck::new(tau, W.clone(), S.num_vars); + + let (sc, rand_sc, claims_mem, claims_outer, claims_inner, claims_witness) = Self::prove_helper( &mut mem_sc_inst, &mut outer_sc_inst, &mut inner_sc_inst, + &mut witness_sc_inst, &mut transcript, )?; @@ -1174,6 +1318,7 @@ impl> RelaxedR1CSSNARKTrait for Relax let eval_t_plus_r_inv_col = claims_mem[1][0]; let eval_w_plus_r_inv_col = claims_mem[1][1]; let eval_ts_col = claims_mem[1][2]; + let eval_W = claims_witness[0][0]; // compute the remaining claims that did not come for free from the sum-check prover let (eval_Cz, eval_E, eval_val_A, eval_val_B, eval_val_C, eval_row, eval_col) = { @@ -1192,8 +1337,9 @@ impl> RelaxedR1CSSNARKTrait for Relax (e[0], e[1], e[2], e[3], e[4], e[5], e[6]) }; - // all the following evaluations are at rand_sc, we can fold them into one claim + // all the evaluations are at rand_sc, we can fold them into one claim let eval_vec = vec![ + eval_W, eval_Az, eval_Bz, eval_Cz, @@ -1216,6 +1362,7 @@ impl> RelaxedR1CSSNARKTrait for Relax .collect::>(); let comm_vec = [ + U.comm_W, comm_Az, comm_Bz, comm_Cz, @@ -1235,6 +1382,7 @@ impl> RelaxedR1CSSNARKTrait for Relax pk.S_comm.comm_ts_col, ]; let poly_vec = [ + &W, &Az, &Bz, &Cz, @@ -1258,21 +1406,7 @@ impl> RelaxedR1CSSNARKTrait for Relax let w: PolyEvalWitness = PolyEvalWitness::batch(&poly_vec, &c); let u: PolyEvalInstance = PolyEvalInstance::batch(&comm_vec, &rand_sc, &eval_vec, &c); - let eval_arg_batch = EE::prove(ck, &pk.pk_ee, &mut transcript, &u.c, &w.p, &rand_sc, &u.e)?; - - // prove eval_W at the shortened vector - let l = pk.S_comm.N.log_2() - (2 * S.num_vars).log_2(); - let rand_sc_unpad = rand_sc[l..].to_vec(); - let eval_W = MultilinearPolynomial::evaluate_with(&W.W, &rand_sc_unpad[1..]); - let eval_arg_W = EE::prove( - ck, - &pk.pk_ee, - &mut transcript, - &U.comm_W, - &W.W, - &rand_sc_unpad[1..], - &eval_W, - )?; + let eval_arg = EE::prove(ck, &pk.pk_ee, &mut transcript, &u.c, &w.p, &rand_sc, &u.e)?; Ok(RelaxedR1CSSNARK { comm_Az: comm_Az.compress(), @@ -1314,8 +1448,7 @@ impl> RelaxedR1CSSNARKTrait for Relax eval_w_plus_r_inv_col, eval_ts_col, - eval_arg_batch, - eval_arg_W, + eval_arg, }) } @@ -1386,7 +1519,7 @@ impl> RelaxedR1CSSNARKTrait for Relax let rho = transcript.squeeze(b"r")?; - let num_claims = 9; + let num_claims = 10; let s = transcript.squeeze(b"r")?; let coeffs = powers::(&s, num_claims); let claim = (coeffs[7] + coeffs[8]) * claim; // rest are zeros @@ -1400,7 +1533,12 @@ impl> RelaxedR1CSSNARKTrait for Relax let poly_eq_coords = PowPolynomial::new(&rho, num_rounds_sc).coordinates(); EqPolynomial::new(poly_eq_coords).evaluate(&rand_sc) }; - let taus_bound_rand_sc = PowPolynomial::new(&tau, num_rounds_sc).evaluate(&rand_sc); + let taus_coords = PowPolynomial::new(&tau, num_rounds_sc).coordinates(); + let eq_tau = EqPolynomial::new(taus_coords); + + let taus_bound_rand_sc = eq_tau.evaluate(&rand_sc); + let taus_masked_bound_rand_sc = + MaskedEqPolynomial::new(&eq_tau, vk.num_vars.log_2()).evaluate(&rand_sc); let eval_t_plus_r_row = { let eval_addr_row = IdentityPolynomial::new(num_rounds_sc).evaluate(&rand_sc); @@ -1430,10 +1568,7 @@ impl> RelaxedR1CSSNARKTrait for Relax factor *= E::Scalar::ONE - r_p } - let rand_sc_unpad = { - let l = vk.S_comm.N.log_2() - (2 * vk.num_vars).log_2(); - rand_sc[l..].to_vec() - }; + let rand_sc_unpad = rand_sc[l..].to_vec(); (factor, rand_sc_unpad) }; @@ -1450,7 +1585,7 @@ impl> RelaxedR1CSSNARKTrait for Relax SparsePolynomial::new(vk.num_vars.log_2(), poly_X).evaluate(&rand_sc_unpad[1..]) }; - factor * ((E::Scalar::ONE - rand_sc_unpad[0]) * self.eval_W + rand_sc_unpad[0] * eval_X) + self.eval_W + factor * rand_sc_unpad[0] * eval_X }; let eval_t = eval_addr_col + gamma * eval_val_col; eval_t + r @@ -1488,7 +1623,12 @@ impl> RelaxedR1CSSNARKTrait for Relax * self.eval_L_col * (self.eval_val_A + c * self.eval_val_B + c * c * self.eval_val_C); - claim_mem_final_expected + claim_outer_final_expected + claim_inner_final_expected + let claim_witness_final_expected = coeffs[9] * taus_masked_bound_rand_sc * self.eval_W; + + claim_mem_final_expected + + claim_outer_final_expected + + claim_inner_final_expected + + claim_witness_final_expected }; if claim_sc_final_expected != claim_sc_final { @@ -1496,6 +1636,7 @@ impl> RelaxedR1CSSNARKTrait for Relax } let eval_vec = vec![ + self.eval_W, self.eval_Az, self.eval_Bz, self.eval_Cz, @@ -1517,6 +1658,7 @@ impl> RelaxedR1CSSNARKTrait for Relax .into_iter() .collect::>(); let comm_vec = [ + U.comm_W, comm_Az, comm_Bz, comm_Cz, @@ -1539,28 +1681,14 @@ impl> RelaxedR1CSSNARKTrait for Relax let c = transcript.squeeze(b"c")?; let u: PolyEvalInstance = PolyEvalInstance::batch(&comm_vec, &rand_sc, &eval_vec, &c); - // verify eval_arg_batch + // verify EE::verify( &vk.vk_ee, &mut transcript, &u.c, &rand_sc, &u.e, - &self.eval_arg_batch, - )?; - - // verify eval_arg_W - let rand_sc_unpad = { - let l = vk.S_comm.N.log_2() - (2 * vk.num_vars).log_2(); - rand_sc[l..].to_vec() - }; - EE::verify( - &vk.vk_ee, - &mut transcript, - &U.comm_W, - &rand_sc_unpad[1..], - &self.eval_W, - &self.eval_arg_W, + &self.eval_arg, )?; Ok(()) diff --git a/src/spartan/snark.rs b/src/spartan/snark.rs index 33d739e61..02b5302d4 100644 --- a/src/spartan/snark.rs +++ b/src/spartan/snark.rs @@ -9,6 +9,7 @@ use crate::{ errors::NovaError, r1cs::{R1CSShape, RelaxedR1CSInstance, RelaxedR1CSWitness, SparseMatrix}, spartan::{ + compute_eval_table_sparse, polys::{eq::EqPolynomial, multilinear::MultilinearPolynomial, multilinear::SparsePolynomial}, powers, sumcheck::SumcheckProof, @@ -19,16 +20,18 @@ use crate::{ snark::{DigestHelperTrait, RelaxedR1CSSNARKTrait}, Engine, TranscriptEngineTrait, }, - Commitment, CommitmentKey, + CommitmentKey, }; + use ff::Field; +use itertools::Itertools as _; use once_cell::sync::OnceCell; use rayon::prelude::*; use serde::{Deserialize, Serialize}; /// A type that represents the prover's key -#[derive(Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] #[serde(bound = "")] pub struct ProverKey> { pk_ee: EE::ProverKey, @@ -36,7 +39,7 @@ pub struct ProverKey> { } /// A type that represents the verifier's key -#[derive(Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] #[serde(bound = "")] pub struct VerifierKey> { vk_ee: EE::VerifierKey, @@ -74,7 +77,7 @@ impl> DigestHelperTrait for VerifierK /// A succinct proof of knowledge of a witness to a relaxed R1CS instance /// The proof is produced using Spartan's combination of the sum-check and /// the commitment to a vector viewed as a polynomial commitment -#[derive(Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] #[serde(bound = "")] pub struct RelaxedR1CSSNARK> { sc_proof_outer: SumcheckProof, @@ -87,7 +90,8 @@ pub struct RelaxedR1CSSNARK> { eval_arg: EE::EvaluationArgument, } -impl> RelaxedR1CSSNARKTrait for RelaxedR1CSSNARK { +impl> RelaxedR1CSSNARKTrait for RelaxedR1CSSNARK +{ type ProverKey = ProverKey; type VerifierKey = VerifierKey; @@ -140,9 +144,9 @@ impl> RelaxedR1CSSNARKTrait for Relax // outer sum-check let tau = (0..num_rounds_x) .map(|_i| transcript.squeeze(b"t")) - .collect::, NovaError>>()?; + .collect::, NovaError>>()?; - let mut poly_tau = MultilinearPolynomial::new(EqPolynomial::new(tau).evals()); + let mut poly_tau = MultilinearPolynomial::new(tau.evals()); let (mut poly_Az, mut poly_Bz, poly_Cz, mut poly_uCz_E) = { let (poly_Az, poly_Bz, poly_Cz) = S.multiply_vec(&z)?; let poly_uCz_E = (0..S.num_cons) @@ -188,45 +192,7 @@ impl> RelaxedR1CSSNARKTrait for Relax let poly_ABC = { // compute the initial evaluation table for R(\tau, x) - let evals_rx = EqPolynomial::new(r_x.clone()).evals(); - - // Bounds "row" variables of (A, B, C) matrices viewed as 2d multilinear polynomials - let compute_eval_table_sparse = - |S: &R1CSShape, rx: &[E::Scalar]| -> (Vec, Vec, Vec) { - assert_eq!(rx.len(), S.num_cons); - - let inner = |M: &SparseMatrix, M_evals: &mut Vec| { - for (row_idx, ptrs) in M.indptr.windows(2).enumerate() { - for (val, col_idx) in M.get_row_unchecked(ptrs.try_into().unwrap()) { - M_evals[*col_idx] += rx[row_idx] * val; - } - } - }; - - let (A_evals, (B_evals, C_evals)) = rayon::join( - || { - let mut A_evals: Vec = vec![E::Scalar::ZERO; 2 * S.num_vars]; - inner(&S.A, &mut A_evals); - A_evals - }, - || { - rayon::join( - || { - let mut B_evals: Vec = vec![E::Scalar::ZERO; 2 * S.num_vars]; - inner(&S.B, &mut B_evals); - B_evals - }, - || { - let mut C_evals: Vec = vec![E::Scalar::ZERO; 2 * S.num_vars]; - inner(&S.C, &mut C_evals); - C_evals - }, - ) - }, - ); - - (A_evals, B_evals, C_evals) - }; + let evals_rx = EqPolynomial::evals_from_points(&r_x.clone()); let (evals_A, evals_B, evals_C) = compute_eval_table_sparse(&S, &evals_rx); @@ -320,7 +286,7 @@ impl> RelaxedR1CSSNARKTrait for Relax // outer sum-check let tau = (0..num_rounds_x) .map(|_i| transcript.squeeze(b"t")) - .collect::, NovaError>>()?; + .collect::, NovaError>>()?; let (claim_outer_final, r_x) = self @@ -329,7 +295,7 @@ impl> RelaxedR1CSSNARKTrait for Relax // verify claim_outer_final let (claim_Az, claim_Bz, claim_Cz) = self.claims_outer; - let taus_bound_rx = EqPolynomial::new(tau).evaluate(&r_x); + let taus_bound_rx = tau.evaluate(&r_x); let claim_outer_final_expected = taus_bound_rx * (claim_Az * claim_Bz - U.u * claim_Cz - self.eval_E); if claim_outer_final != claim_outer_final_expected { @@ -393,8 +359,8 @@ impl> RelaxedR1CSSNARKTrait for Relax }; let (T_x, T_y) = rayon::join( - || EqPolynomial::new(r_x.to_vec()).evals(), - || EqPolynomial::new(r_y.to_vec()).evals(), + || EqPolynomial::evals_from_points(r_x), + || EqPolynomial::evals_from_points(r_y), ); (0..M_vec.len()) @@ -447,7 +413,19 @@ impl> RelaxedR1CSSNARKTrait for Relax /// Proves a batch of polynomial evaluation claims using Sumcheck /// reducing them to a single claim at the same point. -fn batch_eval_prove( +/// +/// # Details +/// +/// We are given as input a list of instance/witness pairs +/// u = [(Cᵢ, xᵢ, eᵢ)], w = [Pᵢ], such that +/// - nᵢ = |xᵢ| +/// - Cᵢ = Commit(Pᵢ) +/// - eᵢ = Pᵢ(xᵢ) +/// - |Pᵢ| = 2^nᵢ +/// +/// We allow the polynomial Pᵢ to have different sizes, by appropriately scaling +/// the claims and resulting evaluations from Sumcheck. +pub(in crate::spartan) fn batch_eval_prove( u_vec: Vec>, w_vec: Vec>, transcript: &mut E::TE, @@ -460,38 +438,44 @@ fn batch_eval_prove( ), NovaError, > { - assert_eq!(u_vec.len(), w_vec.len()); + let num_claims = u_vec.len(); + assert_eq!(w_vec.len(), num_claims); - let w_vec_padded = PolyEvalWitness::pad(w_vec); // pad the polynomials to be of the same size - let u_vec_padded = PolyEvalInstance::pad(u_vec); // pad the evaluation points + // Compute nᵢ and n = maxᵢ{nᵢ} + let num_rounds = u_vec.iter().map(|u| u.x.len()).collect::>(); - // generate a challenge + // Check polynomials match number of variables, i.e. |Pᵢ| = 2^nᵢ + w_vec + .iter() + .zip_eq(num_rounds.iter()) + .for_each(|(w, num_vars)| assert_eq!(w.p.len(), 1 << num_vars)); + + // generate a challenge, and powers of it for random linear combination let rho = transcript.squeeze(b"r")?; - let num_claims = w_vec_padded.len(); let powers_of_rho = powers::(&rho, num_claims); - let claim_batch_joint = u_vec_padded - .iter() - .zip(powers_of_rho.iter()) - .map(|(u, p)| u.e * p) - .sum(); - let mut polys_left: Vec> = w_vec_padded + let (claims, u_xs, comms): (Vec<_>, Vec<_>, Vec<_>) = + u_vec.into_iter().map(|u| (u.e, u.x, u.c)).multiunzip(); + + // Create clones of polynomials to be given to Sumcheck + // Pᵢ(X) + let polys_P: Vec> = w_vec .iter() .map(|w| MultilinearPolynomial::new(w.p.clone())) .collect(); - let mut polys_right: Vec> = u_vec_padded - .iter() - .map(|u| MultilinearPolynomial::new(EqPolynomial::new(u.x.clone()).evals())) + // eq(xᵢ, X) + let polys_eq: Vec> = u_xs + .into_iter() + .map(|ux| MultilinearPolynomial::new(EqPolynomial::evals_from_points(&ux))) .collect(); - let num_rounds_z = u_vec_padded[0].x.len(); - let comb_func = - |poly_A_comp: &E::Scalar, poly_B_comp: &E::Scalar| -> E::Scalar { *poly_A_comp * *poly_B_comp }; - let (sc_proof_batch, r_z, claims_batch) = SumcheckProof::prove_quad_batch( - &claim_batch_joint, - num_rounds_z, - &mut polys_left, - &mut polys_right, + // For each i, check eᵢ = ∑ₓ Pᵢ(x)eq(xᵢ,x), where x ∈ {0,1}^nᵢ + let comb_func = |poly_P: &E::Scalar, poly_eq: &E::Scalar| -> E::Scalar { *poly_P * *poly_eq }; + let (sc_proof_batch, r, claims_batch) = SumcheckProof::prove_quad_batch( + &claims, + &num_rounds, + polys_P, + polys_eq, &powers_of_rho, comb_func, transcript, @@ -501,72 +485,52 @@ fn batch_eval_prove( transcript.absorb(b"l", &claims_batch_left.as_slice()); - // we now combine evaluation claims at the same point rz into one + // we now combine evaluation claims at the same point r into one let gamma = transcript.squeeze(b"g")?; - let powers_of_gamma: Vec = powers::(&gamma, num_claims); - let comm_joint = u_vec_padded - .iter() - .zip(powers_of_gamma.iter()) - .map(|(u, g_i)| u.c * *g_i) - .fold(Commitment::::default(), |acc, item| acc + item); - let poly_joint = PolyEvalWitness::weighted_sum(&w_vec_padded, &powers_of_gamma); - let eval_joint = claims_batch_left - .iter() - .zip(powers_of_gamma.iter()) - .map(|(e, g_i)| *e * *g_i) - .sum(); - - Ok(( - PolyEvalInstance:: { - c: comm_joint, - x: r_z, - e: eval_joint, - }, - poly_joint, - sc_proof_batch, - claims_batch_left, - )) + + let u_joint = + PolyEvalInstance::batch_diff_size(&comms, &claims_batch_left, &num_rounds, r, gamma); + + // P = ∑ᵢ γⁱ⋅Pᵢ + let w_joint = PolyEvalWitness::batch_diff_size(w_vec, gamma); + + Ok((u_joint, w_joint, sc_proof_batch, claims_batch_left)) } /// Verifies a batch of polynomial evaluation claims using Sumcheck /// reducing them to a single claim at the same point. -fn batch_eval_verify( +pub(in crate::spartan) fn batch_eval_verify( u_vec: Vec>, transcript: &mut E::TE, sc_proof_batch: &SumcheckProof, evals_batch: &[E::Scalar], ) -> Result, NovaError> { - assert_eq!(evals_batch.len(), evals_batch.len()); - - let u_vec_padded = PolyEvalInstance::pad(u_vec); // pad the evaluation points + let num_claims = u_vec.len(); + assert_eq!(evals_batch.len(), num_claims); // generate a challenge let rho = transcript.squeeze(b"r")?; - let num_claims: usize = u_vec_padded.len(); let powers_of_rho = powers::(&rho, num_claims); - let claim_batch_joint = u_vec_padded - .iter() - .zip(powers_of_rho.iter()) - .map(|(u, p)| u.e * p) - .sum(); - let num_rounds_z = u_vec_padded[0].x.len(); + // Compute nᵢ and n = maxᵢ{nᵢ} + let num_rounds = u_vec.iter().map(|u| u.x.len()).collect::>(); + let num_rounds_max = *num_rounds.iter().max().unwrap(); + + let claims = u_vec.iter().map(|u| u.e).collect::>(); - let (claim_batch_final, r_z) = - sc_proof_batch.verify(claim_batch_joint, num_rounds_z, 2, transcript)?; + let (claim_batch_final, r) = + sc_proof_batch.verify_batch(&claims, &num_rounds, &powers_of_rho, 2, transcript)?; let claim_batch_final_expected = { - let poly_rz = EqPolynomial::new(r_z.clone()); - let evals = u_vec_padded - .iter() - .map(|u| poly_rz.evaluate(&u.x)) - .collect::>(); - - evals - .iter() - .zip(evals_batch.iter()) - .zip(powers_of_rho.iter()) - .map(|((e_i, p_i), rho_i)| *e_i * *p_i * rho_i) + let evals_r = u_vec.iter().map(|u| { + let (_, r_hi) = r.split_at(num_rounds_max - u.x.len()); + EqPolynomial::new(r_hi.to_vec()).evaluate(&u.x) + }); + + evals_r + .zip_eq(evals_batch.iter()) + .zip_eq(powers_of_rho.iter()) + .map(|((e_i, p_i), rho_i)| e_i * *p_i * rho_i) .sum() }; @@ -576,23 +540,12 @@ fn batch_eval_verify( transcript.absorb(b"l", &evals_batch); - // we now combine evaluation claims at the same point rz into one + // we now combine evaluation claims at the same point r into one let gamma = transcript.squeeze(b"g")?; - let powers_of_gamma: Vec = powers::(&gamma, num_claims); - let comm_joint = u_vec_padded - .iter() - .zip(powers_of_gamma.iter()) - .map(|(u, g_i)| u.c * *g_i) - .fold(Commitment::::default(), |acc, item| acc + item); - let eval_joint = evals_batch - .iter() - .zip(powers_of_gamma.iter()) - .map(|(e, g_i)| *e * *g_i) - .sum(); - - Ok(PolyEvalInstance:: { - c: comm_joint, - x: r_z, - e: eval_joint, - }) + + let comms = u_vec.into_iter().map(|u| u.c).collect::>(); + + let u_joint = PolyEvalInstance::batch_diff_size(&comms, evals_batch, &num_rounds, r, gamma); + + Ok(u_joint) } diff --git a/src/spartan/sumcheck.rs b/src/spartan/sumcheck.rs index fdf168bef..d8d603ca4 100644 --- a/src/spartan/sumcheck.rs +++ b/src/spartan/sumcheck.rs @@ -61,6 +61,40 @@ impl SumcheckProof { Ok((e, r)) } + pub fn verify_batch( + &self, + claims: &[E::Scalar], + num_rounds: &[usize], + coeffs: &[E::Scalar], + degree_bound: usize, + transcript: &mut E::TE, + ) -> Result<(E::Scalar, Vec), NovaError> { + let num_instances = claims.len(); + assert_eq!(num_rounds.len(), num_instances); + assert_eq!(coeffs.len(), num_instances); + + // n = maxᵢ{nᵢ} + let num_rounds_max = *num_rounds.iter().max().unwrap(); + + // Random linear combination of claims, + // where each claim is scaled by 2^{n-nᵢ} to account for the padding. + // + // claim = ∑ᵢ coeffᵢ⋅2^{n-nᵢ}⋅cᵢ + let claim = zip_with!( + ( + zip_with!(iter, (claims, num_rounds), |claim, num_rounds| { + let scaling_factor = 1 << (num_rounds_max - num_rounds); + E::Scalar::from(scaling_factor as u64) * claim + }), + coeffs.iter() + ), + |scaled_claim, coeff| scaled_claim * coeff + ) + .sum(); + + self.verify(claim, num_rounds_max, degree_bound, transcript) + } + #[inline] pub(in crate::spartan) fn compute_eval_points_quad( poly_A: &MultilinearPolynomial, @@ -123,7 +157,7 @@ impl SumcheckProof { // Set up next round claim_per_round = poly.evaluate(&r_i); - // bound all tables to the verifier's challenege + // bind all tables to the verifier's challenge rayon::join( || poly_A.bind_poly_var_top(&r_i), || poly_B.bind_poly_var_top(&r_i), @@ -140,10 +174,10 @@ impl SumcheckProof { } pub fn prove_quad_batch( - claim: &E::Scalar, - num_rounds: usize, - poly_A_vec: &mut Vec>, - poly_B_vec: &mut Vec>, + claims: &[E::Scalar], + num_rounds: &[usize], + mut poly_A_vec: Vec>, + mut poly_B_vec: Vec>, coeffs: &[E::Scalar], comb_func: F, transcript: &mut E::TE, @@ -151,16 +185,58 @@ impl SumcheckProof { where F: Fn(&E::Scalar, &E::Scalar) -> E::Scalar + Sync, { - let mut e = *claim; + let num_claims = claims.len(); + + assert_eq!(num_rounds.len(), num_claims); + assert_eq!(poly_A_vec.len(), num_claims); + assert_eq!(poly_B_vec.len(), num_claims); + assert_eq!(coeffs.len(), num_claims); + + for (i, &num_rounds) in num_rounds.iter().enumerate() { + let expected_size = 1 << num_rounds; + + // Direct indexing with the assumption that the index will always be in bounds + let a = &poly_A_vec[i]; + let b = &poly_B_vec[i]; + + for (l, polyname) in [(a.len(), "poly_A_vec"), (b.len(), "poly_B_vec")].iter() { + assert_eq!( + *l, expected_size, + "Mismatch in size for {} at index {}", + polyname, i + ); + } + } + + let num_rounds_max = *num_rounds.iter().max().unwrap(); + let mut e = zip_with!( + iter, + (claims, num_rounds, coeffs), + |claim, num_rounds, coeff| { + let scaled_claim = E::Scalar::from((1 << (num_rounds_max - num_rounds)) as u64) * claim; + scaled_claim * coeff + } + ) + .sum(); let mut r: Vec = Vec::new(); let mut quad_polys: Vec> = Vec::new(); - for _ in 0..num_rounds { - let evals: Vec<(E::Scalar, E::Scalar)> = poly_A_vec - .par_iter() - .zip(poly_B_vec.par_iter()) - .map(|(poly_A, poly_B)| Self::compute_eval_points_quad(poly_A, poly_B, &comb_func)) - .collect(); + for current_round in 0..num_rounds_max { + let remaining_rounds = num_rounds_max - current_round; + let evals: Vec<(E::Scalar, E::Scalar)> = zip_with!( + par_iter, + (num_rounds, claims, poly_A_vec, poly_B_vec), + |num_rounds, claim, poly_A, poly_B| { + if remaining_rounds <= *num_rounds { + Self::compute_eval_points_quad(poly_A, poly_B, &comb_func) + } else { + let remaining_variables = remaining_rounds - num_rounds - 1; + let scaled_claim = E::Scalar::from((1 << remaining_variables) as u64) * claim; + (scaled_claim, scaled_claim) + } + } + ) + .collect(); let evals_combined_0 = (0..evals.len()).map(|i| evals[i].0 * coeffs[i]).sum(); let evals_combined_2 = (0..evals.len()).map(|i| evals[i].1 * coeffs[i]).sum(); @@ -176,22 +252,45 @@ impl SumcheckProof { r.push(r_i); // bound all tables to the verifier's challenge - poly_A_vec - .par_iter_mut() - .zip(poly_B_vec.par_iter_mut()) - .for_each(|(poly_A, poly_B)| { - let _ = rayon::join( - || poly_A.bind_poly_var_top(&r_i), - || poly_B.bind_poly_var_top(&r_i), - ); - }); + zip_with_for_each!( + ( + num_rounds.par_iter(), + poly_A_vec.par_iter_mut(), + poly_B_vec.par_iter_mut() + ), + |num_rounds, poly_A, poly_B| { + if remaining_rounds <= *num_rounds { + let _ = rayon::join( + || poly_A.bind_poly_var_top(&r_i), + || poly_B.bind_poly_var_top(&r_i), + ); + } + } + ); e = poly.evaluate(&r_i); quad_polys.push(poly.compress()); } + poly_A_vec.iter().for_each(|p| assert_eq!(p.len(), 1)); + poly_B_vec.iter().for_each(|p| assert_eq!(p.len(), 1)); + + let poly_A_final = poly_A_vec + .into_iter() + .map(|poly| poly[0]) + .collect::>(); + let poly_B_final = poly_B_vec + .into_iter() + .map(|poly| poly[0]) + .collect::>(); + + let eval_expected = zip_with!( + iter, + (poly_A_final, poly_B_final, coeffs), + |eA, eB, coeff| comb_func(eA, eB) * coeff + ) + .sum::(); + assert_eq!(e, eval_expected); - let poly_A_final = (0..poly_A_vec.len()).map(|i| poly_A_vec[i][0]).collect(); - let poly_B_final = (0..poly_B_vec.len()).map(|i| poly_B_vec[i][0]).collect(); let claims_prod = (poly_A_final, poly_B_final); Ok((SumcheckProof::new(quad_polys), r, claims_prod)) @@ -360,4 +459,153 @@ impl SumcheckProof { vec![poly_A[0], poly_B[0], poly_C[0], poly_D[0]], )) } + + pub fn prove_cubic_with_additive_term_batch( + claims: &[E::Scalar], + num_rounds: &[usize], + mut poly_A_vec: Vec>, + mut poly_B_vec: Vec>, + mut poly_C_vec: Vec>, + mut poly_D_vec: Vec>, + coeffs: &[E::Scalar], + comb_func: F, + transcript: &mut E::TE, + ) -> Result<(Self, Vec, Vec>), NovaError> + where + F: Fn(&E::Scalar, &E::Scalar, &E::Scalar, &E::Scalar) -> E::Scalar + Sync, + { + let num_instances = claims.len(); + assert_eq!(num_rounds.len(), num_instances); + assert_eq!(coeffs.len(), num_instances); + assert_eq!(poly_A_vec.len(), num_instances); + assert_eq!(poly_B_vec.len(), num_instances); + assert_eq!(poly_C_vec.len(), num_instances); + assert_eq!(poly_D_vec.len(), num_instances); + + for (i, &num_rounds) in num_rounds.iter().enumerate() { + let expected_size = 1 << num_rounds; + + // Direct indexing with the assumption that the index will always be in bounds + let a = &poly_A_vec[i]; + let b = &poly_B_vec[i]; + let c = &poly_C_vec[i]; + let d = &poly_D_vec[i]; + + for (l, polyname) in [ + (a.len(), "poly_A"), + (b.len(), "poly_B"), + (c.len(), "poly_C"), + (d.len(), "poly_D"), + ] + .iter() + { + assert_eq!( + *l, expected_size, + "Mismatch in size for {} at index {}", + polyname, i + ); + } + } + + let num_rounds_max = *num_rounds.iter().max().unwrap(); + + let mut r: Vec = Vec::new(); + let mut polys: Vec> = Vec::new(); + let mut claim_per_round = zip_with!( + iter, + (claims, num_rounds, coeffs), + |claim, num_rounds, coeff| { + let scaled_claim = E::Scalar::from((1 << (num_rounds_max - num_rounds)) as u64) * claim; + scaled_claim * *coeff + } + ) + .sum(); + + for current_round in 0..num_rounds_max { + let remaining_rounds = num_rounds_max - current_round; + let evals: Vec<(E::Scalar, E::Scalar, E::Scalar)> = zip_with!( + par_iter, + (num_rounds, claims, poly_A_vec, poly_B_vec, poly_C_vec, poly_D_vec), + |num_rounds, claim, poly_A, poly_B, poly_C, poly_D| { + if remaining_rounds <= *num_rounds { + Self::compute_eval_points_cubic_with_additive_term( + poly_A, poly_B, poly_C, poly_D, &comb_func, + ) + } else { + let remaining_variables = remaining_rounds - num_rounds - 1; + let scaled_claim = E::Scalar::from((1 << remaining_variables) as u64) * claim; + (scaled_claim, scaled_claim, scaled_claim) + } + } + ) + .collect(); + + let evals_combined_0 = (0..num_instances).map(|i| evals[i].0 * coeffs[i]).sum(); + let evals_combined_2 = (0..num_instances).map(|i| evals[i].1 * coeffs[i]).sum(); + let evals_combined_3 = (0..num_instances).map(|i| evals[i].2 * coeffs[i]).sum(); + + let evals = vec![ + evals_combined_0, + claim_per_round - evals_combined_0, + evals_combined_2, + evals_combined_3, + ]; + let poly = UniPoly::from_evals(&evals); + + // append the prover's message to the transcript + transcript.absorb(b"p", &poly); + + //derive the verifier's challenge for the next round + let r_i = transcript.squeeze(b"c")?; + r.push(r_i); + + polys.push(poly.compress()); + + // Set up next round + claim_per_round = poly.evaluate(&r_i); + + // bound all the tables to the verifier's challenge + + zip_with_for_each!( + ( + num_rounds.par_iter(), + poly_A_vec.par_iter_mut(), + poly_B_vec.par_iter_mut(), + poly_C_vec.par_iter_mut(), + poly_D_vec.par_iter_mut() + ), + |num_rounds, poly_A, poly_B, poly_C, poly_D| { + if remaining_rounds <= *num_rounds { + let _ = rayon::join( + || { + rayon::join( + || poly_A.bind_poly_var_top(&r_i), + || poly_B.bind_poly_var_top(&r_i), + ) + }, + || { + rayon::join( + || poly_C.bind_poly_var_top(&r_i), + || poly_D.bind_poly_var_top(&r_i), + ) + }, + ); + } + } + ); + } + + let poly_A_final = poly_A_vec.into_iter().map(|poly| poly[0]).collect(); + let poly_B_final = poly_B_vec.into_iter().map(|poly| poly[0]).collect(); + let poly_C_final = poly_C_vec.into_iter().map(|poly| poly[0]).collect(); + let poly_D_final = poly_D_vec.into_iter().map(|poly| poly[0]).collect(); + + Ok(( + SumcheckProof { + compressed_polys: polys, + }, + r, + vec![poly_A_final, poly_B_final, poly_C_final, poly_D_final], + )) + } } diff --git a/src/supernova/circuit.rs b/src/supernova/circuit.rs new file mode 100644 index 000000000..21bd771e8 --- /dev/null +++ b/src/supernova/circuit.rs @@ -0,0 +1,749 @@ +//! Supernova implemetation support arbitrary argumented circuits and running instances. +//! There are two Verification Circuits for each argumented circuit: The primary and the secondary. +//! Each of them is over a Pasta curve but +//! only the primary executes the next step of the computation. +//! Each circuit takes as input 2 hashes. +//! Each circuit folds the last invocation of the other into the respective running instance, specified by `augmented_circuit_index` +//! +//! The augmented circuit F' for `SuperNova` that includes everything from Nova +//! and additionally checks: +//! 1. Ui[] are contained in X[0] hash pre-image. +//! 2. R1CS Instance u is folded into Ui[augmented_circuit_index] correctly; just like Nova IVC. +//! 3. (optional by F logic) F circuit might check `program_counter_{i}` invoked current F circuit is legal or not. +//! 3. F circuit produce `program_counter_{i+1}` and sent to next round for optionally constraint the next F' argumented circuit. + +use crate::{ + constants::NUM_HASH_BITS, + gadgets::{ + ecc::AllocatedPoint, + r1cs::{ + conditionally_select_alloc_relaxed_r1cs, + conditionally_select_vec_allocated_relaxed_r1cs_instance, AllocatedR1CSInstance, + AllocatedRelaxedR1CSInstance, + }, + utils::{ + alloc_num_equals, alloc_scalar_as_base, alloc_zero, conditionally_select_vec, le_bits_to_num, + }, + }, + r1cs::{R1CSInstance, RelaxedR1CSInstance}, + traits::{ + circuit_supernova::EnforcingStepCircuit, commitment::CommitmentTrait, Engine, ROCircuitTrait, + ROConstantsCircuit, + }, + Commitment, +}; +use bellpepper_core::{ + boolean::{AllocatedBit, Boolean}, + num::AllocatedNum, + ConstraintSystem, SynthesisError, +}; + +use bellpepper::gadgets::Assignment; + +use ff::Field; +use itertools::Itertools as _; +use serde::{Deserialize, Serialize}; + +use crate::supernova::{ + num_ro_inputs, + utils::{get_from_vec_alloc_relaxed_r1cs, get_selector_vec_from_index}, +}; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct SuperNovaAugmentedCircuitParams { + limb_width: usize, + n_limbs: usize, + is_primary_circuit: bool, // A boolean indicating if this is the primary circuit +} + +impl SuperNovaAugmentedCircuitParams { + pub const fn new(limb_width: usize, n_limbs: usize, is_primary_circuit: bool) -> Self { + Self { + limb_width, + n_limbs, + is_primary_circuit, + } + } + + pub fn get_n_limbs(&self) -> usize { + self.n_limbs + } +} + +#[derive(Debug)] +pub struct SuperNovaAugmentedCircuitInputs<'a, E: Engine> { + pp_digest: E::Scalar, + i: E::Base, + /// Input to the circuit for the base case + z0: &'a [E::Base], + /// Input to the circuit for the non-base case + zi: Option<&'a [E::Base]>, + /// List of `RelaxedR1CSInstance`. + /// `None` if this is the base case. + /// Elements are `None` if the circuit at that index was not yet executed. + U: Option<&'a [Option>]>, + /// R1CS proof to be folded into U + u: Option<&'a R1CSInstance>, + /// Nova folding proof for accumulating u into U[j] + T: Option<&'a Commitment>, + /// Index of the current circuit + program_counter: Option, + /// Index j of circuit being folded into U[j] + last_augmented_circuit_index: E::Base, +} + +impl<'a, E: Engine> SuperNovaAugmentedCircuitInputs<'a, E> { + /// Create new inputs/witness for the verification circuit + pub fn new( + pp_digest: E::Scalar, + i: E::Base, + z0: &'a [E::Base], + zi: Option<&'a [E::Base]>, + U: Option<&'a [Option>]>, + u: Option<&'a R1CSInstance>, + T: Option<&'a Commitment>, + program_counter: Option, + last_augmented_circuit_index: E::Base, + ) -> Self { + Self { + pp_digest, + i, + z0, + zi, + U, + u, + T, + program_counter, + last_augmented_circuit_index, + } + } +} + +/// The augmented circuit F' in `SuperNova` that includes a step circuit F +/// and the circuit for the verifier in `SuperNova`'s non-interactive folding scheme, +/// `SuperNova` NIFS will fold strictly r1cs instance u with respective relaxed r1cs instance `U[last_augmented_circuit_index]` +pub struct SuperNovaAugmentedCircuit<'a, E: Engine, SC: EnforcingStepCircuit> { + params: &'a SuperNovaAugmentedCircuitParams, + ro_consts: ROConstantsCircuit, + inputs: Option>, + step_circuit: &'a SC, // The function that is applied for each step + num_augmented_circuits: usize, // number of overall augmented circuits +} + +impl<'a, E: Engine, SC: EnforcingStepCircuit> SuperNovaAugmentedCircuit<'a, E, SC> { + /// Create a new verification circuit for the input relaxed r1cs instances + pub const fn new( + params: &'a SuperNovaAugmentedCircuitParams, + inputs: Option>, + step_circuit: &'a SC, + ro_consts: ROConstantsCircuit, + num_augmented_circuits: usize, + ) -> Self { + Self { + params, + inputs, + step_circuit, + ro_consts, + num_augmented_circuits, + } + } + + /// Allocate all witnesses from the augmented function's non-deterministic inputs. + /// Optional entries are allocated as their default values. + fn alloc_witness::Base>>( + &self, + mut cs: CS, + arity: usize, + num_augmented_circuits: usize, + ) -> Result< + ( + AllocatedNum, + AllocatedNum, + Vec>, + Vec>, + Vec>, + AllocatedR1CSInstance, + AllocatedPoint, + Option>, + Vec, + ), + SynthesisError, + > { + let last_augmented_circuit_index = + AllocatedNum::alloc(cs.namespace(|| "last_augmented_circuit_index"), || { + Ok(self.inputs.get()?.last_augmented_circuit_index) + })?; + + // Allocate the params + let params = alloc_scalar_as_base::( + cs.namespace(|| "params"), + self.inputs.as_ref().map(|inputs| inputs.pp_digest), + )?; + + // Allocate i + let i = AllocatedNum::alloc(cs.namespace(|| "i"), || Ok(self.inputs.get()?.i))?; + + // Allocate program_counter only on primary circuit + let program_counter = if self.params.is_primary_circuit { + Some(AllocatedNum::alloc( + cs.namespace(|| "program_counter"), + || { + Ok( + self + .inputs + .get()? + .program_counter + .expect("program_counter missing"), + ) + }, + )?) + } else { + None + }; + + // Allocate z0 + let z_0 = (0..arity) + .map(|i| { + AllocatedNum::alloc(cs.namespace(|| format!("z0_{i}")), || { + Ok(self.inputs.get()?.z0[i]) + }) + }) + .collect::>, _>>()?; + + // Allocate zi. If inputs.zi is not provided (base case) allocate default value 0 + let zero = vec![E::Base::ZERO; arity]; + let z_i = (0..arity) + .map(|i| { + AllocatedNum::alloc(cs.namespace(|| format!("zi_{i}")), || { + Ok(self.inputs.get()?.zi.unwrap_or(&zero)[i]) + }) + }) + .collect::>, _>>()?; + + // Allocate the running instances + let U = (0..num_augmented_circuits) + .map(|i| { + AllocatedRelaxedR1CSInstance::alloc( + cs.namespace(|| format!("Allocate U {:?}", i)), + self + .inputs + .as_ref() + .and_then(|inputs| inputs.U.and_then(|U| U[i].as_ref())), + self.params.limb_width, + self.params.n_limbs, + ) + }) + .collect::>, _>>()?; + + // Allocate the r1cs instance to be folded in + let u = AllocatedR1CSInstance::alloc( + cs.namespace(|| "allocate instance u to fold"), + self.inputs.as_ref().and_then(|inputs| inputs.u), + )?; + + // Allocate T + let T = AllocatedPoint::alloc( + cs.namespace(|| "allocate T"), + self + .inputs + .as_ref() + .and_then(|inputs| inputs.T.map(|T| T.to_coordinates())), + )?; + T.check_on_curve(cs.namespace(|| "check T on curve"))?; + + // Compute instance selector + let last_augmented_circuit_selector = get_selector_vec_from_index( + cs.namespace(|| "instance selector"), + &last_augmented_circuit_index, + num_augmented_circuits, + )?; + + Ok(( + params, + i, + z_0, + z_i, + U, + u, + T, + program_counter, + last_augmented_circuit_selector, + )) + } + + /// Synthesizes base case and returns the new relaxed `R1CSInstance` + fn synthesize_base_case::Base>>( + &self, + mut cs: CS, + u: AllocatedR1CSInstance, + last_augmented_circuit_selector: &[Boolean], + ) -> Result>, SynthesisError> { + let mut cs = cs.namespace(|| "alloc U_i default"); + + // Allocate a default relaxed r1cs instance + let default = AllocatedRelaxedR1CSInstance::default( + cs.namespace(|| "Allocate primary U_default".to_string()), + self.params.limb_width, + self.params.n_limbs, + )?; + + // The primary circuit just initialize single AllocatedRelaxedR1CSInstance + let U_default = if self.params.is_primary_circuit { + vec![default] + } else { + // The secondary circuit convert the incoming R1CS instance on index which match last_augmented_circuit_index + let incoming_r1cs = AllocatedRelaxedR1CSInstance::from_r1cs_instance( + cs.namespace(|| "Allocate incoming_r1cs"), + u, + self.params.limb_width, + self.params.n_limbs, + )?; + + last_augmented_circuit_selector + .iter() + .enumerate() + .map(|(i, equal_bit)| { + // If index match last_augmented_circuit_index, then return incoming_r1cs, + // otherwise return the default one + conditionally_select_alloc_relaxed_r1cs( + cs.namespace(|| format!("select on index namespace {:?}", i)), + &incoming_r1cs, + &default, + equal_bit, + ) + }) + .collect::>, _>>()? + }; + Ok(U_default) + } + + /// Synthesizes non base case and returns the new relaxed `R1CSInstance` + /// And a boolean indicating if all checks pass + fn synthesize_non_base_case::Base>>( + &self, + mut cs: CS, + params: &AllocatedNum, + i: &AllocatedNum, + z_0: &[AllocatedNum], + z_i: &[AllocatedNum], + U: &[AllocatedRelaxedR1CSInstance], + u: &AllocatedR1CSInstance, + T: &AllocatedPoint, + arity: usize, + last_augmented_circuit_selector: &[Boolean], + program_counter: &Option>, + ) -> Result<(Vec>, AllocatedBit), SynthesisError> { + // Check that u.x[0] = Hash(params, i, program_counter, z0, zi, U[]) + let mut ro = E::ROCircuit::new( + self.ro_consts.clone(), + num_ro_inputs( + self.num_augmented_circuits, + self.params.get_n_limbs(), + arity, + self.params.is_primary_circuit, + ), + ); + ro.absorb(params); + ro.absorb(i); + + if self.params.is_primary_circuit { + if let Some(program_counter) = program_counter.as_ref() { + ro.absorb(program_counter) + } else { + Err(SynthesisError::AssignmentMissing)? + } + } + + for e in z_0 { + ro.absorb(e); + } + for e in z_i { + ro.absorb(e); + } + + U.iter().enumerate().try_for_each(|(i, U)| { + U.absorb_in_ro(cs.namespace(|| format!("absorb U {:?}", i)), &mut ro) + })?; + + let hash_bits = ro.squeeze(cs.namespace(|| "Input hash"), NUM_HASH_BITS)?; + let hash = le_bits_to_num(cs.namespace(|| "bits to hash"), &hash_bits)?; + let check_pass: AllocatedBit = alloc_num_equals( + cs.namespace(|| "check consistency of u.X[0] with H(params, U, i, z0, zi)"), + &u.X0, + &hash, + )?; + + // Run NIFS Verifier + let U_to_fold = get_from_vec_alloc_relaxed_r1cs( + cs.namespace(|| "U to fold"), + U, + last_augmented_circuit_selector, + )?; + let U_fold = U_to_fold.fold_with_r1cs( + cs.namespace(|| "compute fold of U and u"), + params, + u, + T, + self.ro_consts.clone(), + self.params.limb_width, + self.params.n_limbs, + )?; + + // update AllocatedRelaxedR1CSInstance on index match augmented circuit index + let U_next: Vec> = U + .iter() + .zip_eq(last_augmented_circuit_selector.iter()) + .map(|(U, equal_bit)| { + conditionally_select_alloc_relaxed_r1cs( + cs.namespace(|| "select on index namespace"), + &U_fold, + U, + equal_bit, + ) + }) + .collect::>, _>>()?; + + Ok((U_next, check_pass)) + } + + pub fn synthesize::Base>>( + self, + cs: &mut CS, + ) -> Result<(Option>, Vec>), SynthesisError> { + let arity = self.step_circuit.arity(); + let num_augmented_circuits = if self.params.is_primary_circuit { + // primary circuit only fold single running instance with secondary output strict r1cs instance + 1 + } else { + // secondary circuit contains the logic to choose one of multiple augments running instance to fold + self.num_augmented_circuits + }; + + if self.inputs.is_some() { + // Check arity of z0 + let z0_len = self.inputs.as_ref().map_or(0, |inputs| inputs.z0.len()); + if self.step_circuit.arity() != z0_len { + return Err(SynthesisError::IncompatibleLengthVector(format!( + "z0_len {:?} != arity lengh {:?}", + z0_len, + self.step_circuit.arity() + ))); + } + + // The primary curve should always fold the circuit with index 0 + let last_augmented_circuit_index = self + .inputs + .get() + .map_or(E::Base::ZERO, |inputs| inputs.last_augmented_circuit_index); + if self.params.is_primary_circuit && last_augmented_circuit_index != E::Base::ZERO { + return Err(SynthesisError::IncompatibleLengthVector( + "primary circuit running instance only valid on index 0".to_string(), + )); + } + } + + // Allocate witnesses + let (params, i, z_0, z_i, U, u, T, program_counter, last_augmented_circuit_selector) = self + .alloc_witness( + cs.namespace(|| "allocate the circuit witness"), + arity, + num_augmented_circuits, + )?; + + // Compute variable indicating if this is the base case + let zero = alloc_zero(cs.namespace(|| "zero")); + let is_base_case = alloc_num_equals(cs.namespace(|| "Check if base case"), &i.clone(), &zero)?; + + // Synthesize the circuit for the non-base case and get the new running + // instances along with a boolean indicating if all checks have passed + // must use return `last_augmented_circuit_index_checked` since it got range checked + let (U_next_non_base, check_non_base_pass) = self.synthesize_non_base_case( + cs.namespace(|| "synthesize non base case"), + ¶ms, + &i, + &z_0, + &z_i, + &U, + &u, + &T, + arity, + &last_augmented_circuit_selector, + &program_counter, + )?; + + // Synthesize the circuit for the base case and get the new running instances + let U_next_base = self.synthesize_base_case( + cs.namespace(|| "base case"), + u.clone(), + &last_augmented_circuit_selector, + )?; + + // Either check_non_base_pass=true or we are in the base case + let should_be_false = AllocatedBit::nor( + cs.namespace(|| "check_non_base_pass nor base_case"), + &check_non_base_pass, + &is_base_case, + )?; + cs.enforce( + || "check_non_base_pass nor base_case = false", + |lc| lc + should_be_false.get_variable(), + |lc| lc + CS::one(), + |lc| lc, + ); + + // Compute the U_next + let U_next = conditionally_select_vec_allocated_relaxed_r1cs_instance( + cs.namespace(|| "U_next"), + &U_next_base[..], + &U_next_non_base[..], + &Boolean::from(is_base_case.clone()), + )?; + + // Compute i + 1 + let i_next = AllocatedNum::alloc(cs.namespace(|| "i + 1"), || { + Ok(*i.get_value().get()? + E::Base::ONE) + })?; + cs.enforce( + || "check i + 1", + |lc| lc + i.get_variable() + CS::one(), + |lc| lc + CS::one(), + |lc| lc + i_next.get_variable(), + ); + + // Compute z_{i+1} + let z_input = conditionally_select_vec( + cs.namespace(|| "select input to F"), + &z_0, + &z_i, + &Boolean::from(is_base_case), + )?; + + let (program_counter_new, z_next) = self.step_circuit.enforcing_synthesize( + &mut cs.namespace(|| "F"), + program_counter.as_ref(), + &z_input, + )?; + + if z_next.len() != arity { + return Err(SynthesisError::IncompatibleLengthVector( + "z_next".to_string(), + )); + } + + // To check correct folding sequencing we are just going to make a hash. + // The next RunningInstance folding can take the pre-image of this hash as witness and check. + + // "Finally, there is a subtle sizing issue in the above description: in each step, + // because Ui+1 is produced as the public IO of F0 program_counter+1, it must be contained in + // the public IO of instance ui+1. In the next iteration, because ui+1 is folded + // into Ui+1[program_counter+1], this means that Ui+1[program_counter+1] is at least as large as Ui by the + // properties of the folding scheme. This means that the list of running instances + // grows in each step. To alleviate this issue, we have each F0j only produce a hash + // of its outputs as public output. In the subsequent step, the next augmented + // function takes as non-deterministic input a preimage to this hash." pg.16 + + // https://eprint.iacr.org/2022/1758.pdf + + // Compute the new hash H(params, i+1, program_counter, z0, z_{i+1}, U_next) + let mut ro = E::ROCircuit::new( + self.ro_consts.clone(), + num_ro_inputs( + self.num_augmented_circuits, + self.params.get_n_limbs(), + self.step_circuit.arity(), + self.params.is_primary_circuit, + ), + ); + ro.absorb(¶ms); + ro.absorb(&i_next); + // optionally absorb program counter if exist + if program_counter.is_some() { + ro.absorb( + program_counter_new + .as_ref() + .expect("new program counter missing"), + ) + } + for e in &z_0 { + ro.absorb(e); + } + for e in &z_next { + ro.absorb(e); + } + U_next.iter().enumerate().try_for_each(|(i, U)| { + U.absorb_in_ro(cs.namespace(|| format!("absorb U_new {:?}", i)), &mut ro) + })?; + + let hash_bits = ro.squeeze(cs.namespace(|| "output hash bits"), NUM_HASH_BITS)?; + let hash = le_bits_to_num(cs.namespace(|| "convert hash to num"), &hash_bits)?; + + // We are cycling of curve implementation, so primary/secondary will rotate hash in IO for the others to check + // bypass unmodified hash of other circuit as next X[0] + // and output the computed the computed hash as next X[1] + u.X1 + .inputize(cs.namespace(|| "bypass unmodified hash of the other circuit"))?; + hash.inputize(cs.namespace(|| "output new hash of this circuit"))?; + + Ok((program_counter_new, z_next)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + bellpepper::{ + r1cs::{NovaShape, NovaWitness}, + solver::SatisfyingAssignment, + test_shape_cs::TestShapeCS, + }, + constants::{BN_LIMB_WIDTH, BN_N_LIMBS}, + gadgets::utils::scalar_as_base, + provider::{ + poseidon::PoseidonConstantsCircuit, + {Bn256Engine, GrumpkinEngine}, {PallasEngine, VestaEngine}, + {Secp256k1Engine, Secq256k1Engine}, + }, + traits::{circuit_supernova::TrivialTestCircuit, snark::default_ck_hint}, + }; + + // In the following we use 1 to refer to the primary, and 2 to refer to the secondary circuit + fn test_supernova_recursive_circuit_with( + primary_params: &SuperNovaAugmentedCircuitParams, + secondary_params: &SuperNovaAugmentedCircuitParams, + ro_consts1: ROConstantsCircuit, + ro_consts2: ROConstantsCircuit, + num_constraints_primary: usize, + num_constraints_secondary: usize, + num_augmented_circuits: usize, + ) where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + { + let tc1 = TrivialTestCircuit::default(); + // Initialize the shape and ck for the primary + let circuit1: SuperNovaAugmentedCircuit<'_, E2, TrivialTestCircuit<::Base>> = + SuperNovaAugmentedCircuit::new( + primary_params, + None, + &tc1, + ro_consts1.clone(), + num_augmented_circuits, + ); + let mut cs: TestShapeCS = TestShapeCS::new(); + let _ = circuit1.synthesize(&mut cs); + let (shape1, ck1) = cs.r1cs_shape_and_key(&*default_ck_hint()); + assert_eq!(cs.num_constraints(), num_constraints_primary); + + let tc2 = TrivialTestCircuit::default(); + // Initialize the shape and ck for the secondary + let circuit2: SuperNovaAugmentedCircuit<'_, E1, TrivialTestCircuit<::Base>> = + SuperNovaAugmentedCircuit::new( + secondary_params, + None, + &tc2, + ro_consts2.clone(), + num_augmented_circuits, + ); + let mut cs: TestShapeCS = TestShapeCS::new(); + let _ = circuit2.synthesize(&mut cs); + let (shape2, ck2) = cs.r1cs_shape_and_key(&*default_ck_hint()); + assert_eq!(cs.num_constraints(), num_constraints_secondary); + + // Execute the base case for the primary + let zero1 = <::Base as Field>::ZERO; + let mut cs1 = SatisfyingAssignment::::new(); + let vzero1 = vec![zero1]; + let inputs1: SuperNovaAugmentedCircuitInputs<'_, E2> = SuperNovaAugmentedCircuitInputs::new( + scalar_as_base::(zero1), // pass zero for testing + zero1, + &vzero1, + None, + None, + None, + None, + Some(zero1), + zero1, + ); + let circuit1: SuperNovaAugmentedCircuit<'_, E2, TrivialTestCircuit<::Base>> = + SuperNovaAugmentedCircuit::new( + primary_params, + Some(inputs1), + &tc1, + ro_consts1, + num_augmented_circuits, + ); + let _ = circuit1.synthesize(&mut cs1); + let (inst1, witness1) = cs1.r1cs_instance_and_witness(&shape1, &ck1).unwrap(); + // Make sure that this is satisfiable + assert!(shape1.is_sat(&ck1, &inst1, &witness1).is_ok()); + + // Execute the base case for the secondary + let zero2 = <::Base as Field>::ZERO; + let mut cs2 = SatisfyingAssignment::::new(); + let vzero2 = vec![zero2]; + let inputs2: SuperNovaAugmentedCircuitInputs<'_, E1> = SuperNovaAugmentedCircuitInputs::new( + scalar_as_base::(zero2), // pass zero for testing + zero2, + &vzero2, + None, + None, + Some(&inst1), + None, + Some(zero2), + zero2, + ); + let circuit2: SuperNovaAugmentedCircuit<'_, E1, TrivialTestCircuit<::Base>> = + SuperNovaAugmentedCircuit::new( + secondary_params, + Some(inputs2), + &tc2, + ro_consts2, + num_augmented_circuits, + ); + let _ = circuit2.synthesize(&mut cs2); + let (inst2, witness2) = cs2.r1cs_instance_and_witness(&shape2, &ck2).unwrap(); + // Make sure that it is satisfiable + assert!(shape2.is_sat(&ck2, &inst2, &witness2).is_ok()); + } + + #[test] + fn test_supernova_recursive_circuit_pasta() { + // this test checks against values that must be replicated in benchmarks if changed here + let params1 = SuperNovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, true); + let params2 = SuperNovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, false); + let ro_consts1: ROConstantsCircuit = PoseidonConstantsCircuit::default(); + let ro_consts2: ROConstantsCircuit = PoseidonConstantsCircuit::default(); + + test_supernova_recursive_circuit_with::( + ¶ms1, ¶ms2, ro_consts1, ro_consts2, 9844, 10392, 1, + ); + // TODO: extend to num_augmented_circuits >= 2 + } + + #[test] + fn test_supernova_recursive_circuit_grumpkin() { + let params1 = SuperNovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, true); + let params2 = SuperNovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, false); + let ro_consts1: ROConstantsCircuit = PoseidonConstantsCircuit::default(); + let ro_consts2: ROConstantsCircuit = PoseidonConstantsCircuit::default(); + + test_supernova_recursive_circuit_with::( + ¶ms1, ¶ms2, ro_consts1, ro_consts2, 10012, 10581, 1, + ); + // TODO: extend to num_augmented_circuits >= 2 + } + + #[test] + fn test_supernova_recursive_circuit_secp() { + let params1 = SuperNovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, true); + let params2 = SuperNovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, false); + let ro_consts1: ROConstantsCircuit = PoseidonConstantsCircuit::default(); + let ro_consts2: ROConstantsCircuit = PoseidonConstantsCircuit::default(); + + test_supernova_recursive_circuit_with::( + ¶ms1, ¶ms2, ro_consts1, ro_consts2, 10291, 11004, 1, + ); + // TODO: extend to num_augmented_circuits >= 2 + } +} diff --git a/src/supernova/error.rs b/src/supernova/error.rs new file mode 100644 index 000000000..b24db047a --- /dev/null +++ b/src/supernova/error.rs @@ -0,0 +1,19 @@ +//! This module defines errors returned by the library. +use core::fmt::Debug; +use thiserror::Error; + +use crate::errors::NovaError; + +/// Errors returned by Nova +#[derive(Clone, Debug, Eq, PartialEq, Error)] +pub enum SuperNovaError { + /// Nova error + #[error("NovaError")] + NovaError(#[from] NovaError), + /// missig commitment key + #[error("MissingCK")] + MissingCK, + /// Extended error for supernova + #[error("UnSatIndex")] + UnSatIndex(&'static str, usize), +} diff --git a/src/supernova/mod.rs b/src/supernova/mod.rs new file mode 100644 index 000000000..f2c9a5d3e --- /dev/null +++ b/src/supernova/mod.rs @@ -0,0 +1,1082 @@ +#![doc = include_str!("../../notes/supernova.md")] + +use std::marker::PhantomData; +use std::ops::Index; + +use crate::{ + bellpepper::shape_cs::ShapeCS, + constants::{BN_LIMB_WIDTH, BN_N_LIMBS, NUM_HASH_BITS}, + digest::{DigestComputer, SimpleDigestible}, + errors::NovaError, + r1cs::{ + commitment_key_size, CommitmentKeyHint, R1CSInstance, R1CSShape, R1CSWitness, + RelaxedR1CSInstance, RelaxedR1CSWitness, + }, + scalar_as_base, + traits::{ + circuit_supernova::StepCircuit, + commitment::{CommitmentEngineTrait, CommitmentTrait}, + AbsorbInROTrait, Engine, ROConstants, ROConstantsCircuit, ROTrait, + }, + CircuitShape, Commitment, CommitmentKey, +}; +use ff::Field; +use itertools::Itertools as _; +use once_cell::sync::OnceCell; +use rayon::prelude::*; +use serde::{Deserialize, Serialize}; +use log::debug; + +use crate::bellpepper::{ + r1cs::{NovaShape, NovaWitness}, + solver::SatisfyingAssignment, +}; +use bellpepper_core::ConstraintSystem; + +use crate::nifs::NIFS; + +mod circuit; // declare the module first +use circuit::{ + SuperNovaAugmentedCircuit, SuperNovaAugmentedCircuitInputs, SuperNovaAugmentedCircuitParams, +}; + +use self::error::SuperNovaError; + +/// A struct that manages all the digests of the primary circuits of a SuperNova instance +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct CircuitDigests { + digests: Vec, +} + +impl SimpleDigestible for CircuitDigests {} + +impl std::ops::Deref for CircuitDigests { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.digests + } +} + +impl CircuitDigests { + /// Construct a new [CircuitDigests] + pub fn new(digests: Vec) -> Self { + CircuitDigests { digests } + } + + /// Return the [CircuitDigests]' digest. + pub fn digest(&self) -> E::Scalar { + let dc: DigestComputer<'_, ::Scalar, CircuitDigests> = + DigestComputer::new(self); + dc.digest().expect("Failure in computing digest") + } +} + +/// A vector of [CircuitParams] corresponding to a set of [PublicParams] +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct PublicParams +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + C1: StepCircuit, + C2: StepCircuit, +{ + /// The internal circuit shapes + pub circuit_shapes: Vec>, + + ro_consts_primary: ROConstants, + ro_consts_circuit_primary: ROConstantsCircuit, + ck_primary: CommitmentKey, // This is shared between all circuit params + augmented_circuit_params_primary: SuperNovaAugmentedCircuitParams, + + ro_consts_secondary: ROConstants, + ro_consts_circuit_secondary: ROConstantsCircuit, + ck_secondary: CommitmentKey, + circuit_shape_secondary: CircuitShape, + augmented_circuit_params_secondary: SuperNovaAugmentedCircuitParams, + + /// Digest constructed from this `PublicParams`' parameters + #[serde(skip, default = "OnceCell::new")] + digest: OnceCell, + _p: PhantomData<(C1, C2)>, +} + +/// Auxilliary [PublicParams] information about the commitment keys and +/// secondary circuit. This is used as a helper struct when reconstructing +/// [PublicParams] downstream in lurk. +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct AuxParams +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, +{ + ro_consts_primary: ROConstants, + ro_consts_circuit_primary: ROConstantsCircuit, + ck_primary: CommitmentKey, // This is shared between all circuit params + augmented_circuit_params_primary: SuperNovaAugmentedCircuitParams, + + ro_consts_secondary: ROConstants, + ro_consts_circuit_secondary: ROConstantsCircuit, + ck_secondary: CommitmentKey, + circuit_shape_secondary: CircuitShape, + augmented_circuit_params_secondary: SuperNovaAugmentedCircuitParams, + + digest: E1::Scalar, +} + +impl Index for PublicParams +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + C1: StepCircuit, + C2: StepCircuit, +{ + type Output = CircuitShape; + + fn index(&self, index: usize) -> &Self::Output { + &self.circuit_shapes[index] + } +} + +impl SimpleDigestible for PublicParams +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + C1: StepCircuit, + C2: StepCircuit, +{ +} + +impl PublicParams +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + C1: StepCircuit, + C2: StepCircuit, +{ + /// Construct a new [PublicParams] + /// + /// # Note + /// + /// Public parameters set up a number of bases for the homomorphic commitment scheme of Nova. + /// + /// Some final compressing SNARKs, like variants of Spartan, use computation commitments that require + /// larger sizes for these parameters. These SNARKs provide a hint for these values by + /// implementing `RelaxedR1CSSNARKTrait::commitment_key_floor()`, which can be passed to this function. + /// + /// If you're not using such a SNARK, pass `&(|_| 0)` instead. + /// + /// # Arguments + /// + /// * `non_uniform_circuit`: The non-uniform circuit of type `NC`. + /// * `ck_hint1`: A `CommitmentKeyHint` for `E1`, which is a function that provides a hint + /// for the number of generators required in the commitment scheme for the primary circuit. + /// * `ck_hint2`: A `CommitmentKeyHint` for `E2`, similar to `ck_hint1`, but for the secondary circuit. + pub fn setup>( + non_uniform_circuit: &NC, + ck_hint1: &CommitmentKeyHint, + ck_hint2: &CommitmentKeyHint, + ) -> Self { + let num_circuits = non_uniform_circuit.num_circuits(); + + let augmented_circuit_params_primary = + SuperNovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, true); + let ro_consts_primary: ROConstants = ROConstants::::default(); + // ro_consts_circuit_primary are parameterized by E2 because the type alias uses E2::Base = E1::Scalar + let ro_consts_circuit_primary: ROConstantsCircuit = ROConstantsCircuit::::default(); + + let circuit_shapes = (0..num_circuits) + .map(|i| { + let c_primary = non_uniform_circuit.primary_circuit(i); + let F_arity = c_primary.arity(); + // Initialize ck for the primary + let circuit_primary: SuperNovaAugmentedCircuit<'_, E2, C1> = SuperNovaAugmentedCircuit::new( + &augmented_circuit_params_primary, + None, + &c_primary, + ro_consts_circuit_primary.clone(), + num_circuits, + ); + let mut cs: ShapeCS = ShapeCS::new(); + circuit_primary + .synthesize(&mut cs) + .expect("circuit synthesis failed"); + + // We use the largest commitment_key for all instances + let r1cs_shape_primary = cs.r1cs_shape(); + CircuitShape::new(r1cs_shape_primary, F_arity) + }) + .collect::>(); + + let ck_primary = Self::compute_primary_ck(&circuit_shapes, ck_hint1); + + let augmented_circuit_params_secondary = + SuperNovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, false); + let ro_consts_secondary: ROConstants = ROConstants::::default(); + let c_secondary = non_uniform_circuit.secondary_circuit(); + let F_arity_secondary = c_secondary.arity(); + let ro_consts_circuit_secondary: ROConstantsCircuit = ROConstantsCircuit::::default(); + + let circuit_secondary: SuperNovaAugmentedCircuit<'_, E1, C2> = SuperNovaAugmentedCircuit::new( + &augmented_circuit_params_secondary, + None, + &c_secondary, + ro_consts_circuit_secondary.clone(), + num_circuits, + ); + let mut cs: ShapeCS = ShapeCS::new(); + circuit_secondary + .synthesize(&mut cs) + .expect("circuit synthesis failed"); + let (r1cs_shape_secondary, ck_secondary) = cs.r1cs_shape_and_key(ck_hint2); + let circuit_shape_secondary = CircuitShape::new(r1cs_shape_secondary, F_arity_secondary); + + let pp = PublicParams { + circuit_shapes, + ro_consts_primary, + ro_consts_circuit_primary, + ck_primary, + augmented_circuit_params_primary, + ro_consts_secondary, + ro_consts_circuit_secondary, + ck_secondary, + circuit_shape_secondary, + augmented_circuit_params_secondary, + digest: OnceCell::new(), + _p: PhantomData, + }; + + // make sure to initialize the `OnceCell` and compute the digest + // and avoid paying for unexpected performance costs later + pp.digest(); + pp + } + + /// Breaks down an instance of [PublicParams] into the circuit params and auxilliary params. + pub fn into_parts(self) -> (Vec>, AuxParams) { + let digest = self.digest(); + + let PublicParams { + circuit_shapes, + ro_consts_primary, + ro_consts_circuit_primary, + ck_primary, + augmented_circuit_params_primary, + ro_consts_secondary, + ro_consts_circuit_secondary, + ck_secondary, + circuit_shape_secondary, + augmented_circuit_params_secondary, + digest: _digest, + _p, + } = self; + + let aux_params = AuxParams { + ro_consts_primary, + ro_consts_circuit_primary, + ck_primary, + augmented_circuit_params_primary, + ro_consts_secondary, + ro_consts_circuit_secondary, + ck_secondary, + circuit_shape_secondary, + augmented_circuit_params_secondary, + digest, + }; + + (circuit_shapes, aux_params) + } + + /// Create a [PublicParams] from a vector of raw [CircuitShape] and auxilliary params. + pub fn from_parts(circuit_shapes: Vec>, aux_params: AuxParams) -> Self { + let pp = PublicParams { + circuit_shapes, + ro_consts_primary: aux_params.ro_consts_primary, + ro_consts_circuit_primary: aux_params.ro_consts_circuit_primary, + ck_primary: aux_params.ck_primary, + augmented_circuit_params_primary: aux_params.augmented_circuit_params_primary, + ro_consts_secondary: aux_params.ro_consts_secondary, + ro_consts_circuit_secondary: aux_params.ro_consts_circuit_secondary, + ck_secondary: aux_params.ck_secondary, + circuit_shape_secondary: aux_params.circuit_shape_secondary, + augmented_circuit_params_secondary: aux_params.augmented_circuit_params_secondary, + digest: OnceCell::new(), + _p: PhantomData, + }; + assert_eq!( + aux_params.digest, + pp.digest(), + "param data is invalid; aux_params contained the incorrect digest" + ); + pp + } + + /// Create a [PublicParams] from a vector of raw [CircuitShape] and auxilliary params. + /// We don't check that the `aux_params.digest` is a valid digest for the created params. + pub fn from_parts_unchecked( + circuit_shapes: Vec>, + aux_params: AuxParams, + ) -> Self { + PublicParams { + circuit_shapes, + ro_consts_primary: aux_params.ro_consts_primary, + ro_consts_circuit_primary: aux_params.ro_consts_circuit_primary, + ck_primary: aux_params.ck_primary, + augmented_circuit_params_primary: aux_params.augmented_circuit_params_primary, + ro_consts_secondary: aux_params.ro_consts_secondary, + ro_consts_circuit_secondary: aux_params.ro_consts_circuit_secondary, + ck_secondary: aux_params.ck_secondary, + circuit_shape_secondary: aux_params.circuit_shape_secondary, + augmented_circuit_params_secondary: aux_params.augmented_circuit_params_secondary, + digest: aux_params.digest.into(), + _p: PhantomData, + } + } + + /// Compute primary and secondary commitment keys sized to handle the largest of the circuits in the provided + /// `CircuitShape`. + fn compute_primary_ck( + circuit_params: &[CircuitShape], + ck_hint1: &CommitmentKeyHint, + ) -> CommitmentKey { + let size_primary = circuit_params + .iter() + .map(|circuit| commitment_key_size(&circuit.r1cs_shape, ck_hint1)) + .max() + .unwrap(); + + E1::CE::setup(b"ck", size_primary) + } + + /// Return the [PublicParams]' digest. + pub fn digest(&self) -> E1::Scalar { + self + .digest + .get_or_try_init(|| { + let dc: DigestComputer<'_, ::Scalar, PublicParams> = + DigestComputer::new(self); + dc.digest() + }) + .cloned() + .expect("Failure in retrieving digest") + } + + /// All of the primary circuit digests of this [PublicParams] + pub fn circuit_param_digests(&self) -> CircuitDigests { + let digests = self + .circuit_shapes + .iter() + .map(|cp| cp.digest()) + .collect::>(); + CircuitDigests { digests } + } + + /// Returns all the primary R1CS Shapes + pub fn primary_r1cs_shapes(&self) -> Vec<&R1CSShape> { + self + .circuit_shapes + .iter() + .map(|cs| &cs.r1cs_shape) + .collect::>() + } +} + +/// A SNARK that proves the correct execution of an non-uniform incremental computation +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct RecursiveSNARK +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, +{ + // Cached digest of the public parameters + pp_digest: E1::Scalar, + num_augmented_circuits: usize, + + // Number of iterations performed up to now + i: usize, + + // Inputs and outputs of the primary circuits + z0_primary: Vec, + zi_primary: Vec, + + // Proven circuit index, and current program counter + proven_circuit_index: usize, + program_counter: E1::Scalar, + + // Relaxed instances for the primary circuits + // Entries are `None` if the circuit has not been executed yet + r_W_primary: Vec>>, + r_U_primary: Vec>>, + + // Inputs and outputs of the secondary circuit + z0_secondary: Vec, + zi_secondary: Vec, + // Relaxed instance for the secondary circuit + r_W_secondary: RelaxedR1CSWitness, + r_U_secondary: RelaxedR1CSInstance, + // Proof for the secondary circuit to be accumulated into r_secondary in the next iteration + l_w_secondary: R1CSWitness, + l_u_secondary: R1CSInstance, +} + +impl RecursiveSNARK +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, +{ + /// iterate base step to get new instance of recursive SNARK + #[allow(clippy::too_many_arguments)] + pub fn new< + C0: NonUniformCircuit, + C1: StepCircuit, + C2: StepCircuit, + >( + pp: &PublicParams, + non_uniform_circuit: &C0, + c_primary: &C1, + c_secondary: &C2, + z0_primary: &[E1::Scalar], + z0_secondary: &[E2::Scalar], + ) -> Result { + let num_augmented_circuits = non_uniform_circuit.num_circuits(); + let circuit_index = non_uniform_circuit.initial_circuit_index(); + + // check the length of the secondary initial input + if z0_secondary.len() != pp.circuit_shape_secondary.F_arity { + return Err(SuperNovaError::NovaError( + NovaError::InvalidStepOutputLength, + )); + } + + // check the arity of all the primary circuits match the initial input length + pp.circuit_shapes.iter().try_for_each(|circuit| { + if circuit.F_arity != z0_primary.len() { + return Err(SuperNovaError::NovaError( + NovaError::InvalidStepOutputLength, + )); + } + Ok(()) + })?; + + // base case for the primary + let mut cs_primary = SatisfyingAssignment::::new(); + let program_counter = E1::Scalar::from(circuit_index as u64); + let inputs_primary: SuperNovaAugmentedCircuitInputs<'_, E2> = + SuperNovaAugmentedCircuitInputs::new( + scalar_as_base::(pp.digest()), + E1::Scalar::ZERO, + z0_primary, + None, // zi = None for basecase + None, // U = [None], since no previous proofs have been computed + None, // u = None since we are not verifying a secondary circuit + None, // T = None since there is not proof to fold + Some(program_counter), // pc = initial_program_counter for primary circuit + E1::Scalar::ZERO, // u_index is always zero for the primary circuit + ); + + let circuit_primary: SuperNovaAugmentedCircuit<'_, E2, C1> = SuperNovaAugmentedCircuit::new( + &pp.augmented_circuit_params_primary, + Some(inputs_primary), + c_primary, + pp.ro_consts_circuit_primary.clone(), + num_augmented_circuits, + ); + + let (zi_primary_pc_next, zi_primary) = + circuit_primary.synthesize(&mut cs_primary).map_err(|_err| { + NovaError::SynthesisError + })?; + if zi_primary.len() != pp[circuit_index].F_arity { + return Err(SuperNovaError::NovaError( + NovaError::InvalidStepOutputLength, + )); + } + let (u_primary, w_primary) = cs_primary + .r1cs_instance_and_witness(&pp[circuit_index].r1cs_shape, &pp.ck_primary) + .map_err(|_err| { + NovaError::SynthesisError + })?; + + // base case for the secondary + let mut cs_secondary = SatisfyingAssignment::::new(); + let u_primary_index = E2::Scalar::from(circuit_index as u64); + let inputs_secondary: SuperNovaAugmentedCircuitInputs<'_, E1> = + SuperNovaAugmentedCircuitInputs::new( + pp.digest(), + E2::Scalar::ZERO, + z0_secondary, + None, // zi = None for basecase + None, // U = Empty list of accumulators for the primary circuits + Some(&u_primary), // Proof for first iteration of current primary circuit + None, // T = None, since we just copy u_primary rather than fold it + None, // program_counter is always None for secondary circuit + u_primary_index, // index of the circuit proof u_primary + ); + let circuit_secondary: SuperNovaAugmentedCircuit<'_, E1, C2> = SuperNovaAugmentedCircuit::new( + &pp.augmented_circuit_params_secondary, + Some(inputs_secondary), + c_secondary, + pp.ro_consts_circuit_secondary.clone(), + num_augmented_circuits, + ); + let (_, zi_secondary) = circuit_secondary + .synthesize(&mut cs_secondary) + .map_err(|_| NovaError::SynthesisError)?; + if zi_secondary.len() != pp.circuit_shape_secondary.F_arity { + return Err(NovaError::InvalidStepOutputLength.into()); + } + let (u_secondary, w_secondary) = cs_secondary + .r1cs_instance_and_witness(&pp.circuit_shape_secondary.r1cs_shape, &pp.ck_secondary) + .map_err(|_| SuperNovaError::NovaError(NovaError::UnSat))?; + + // IVC proof for the primary circuit + let l_w_primary = w_primary; + let l_u_primary = u_primary; + let r_W_primary = + RelaxedR1CSWitness::from_r1cs_witness(&pp[circuit_index].r1cs_shape, &l_w_primary); + + let r_U_primary = RelaxedR1CSInstance::from_r1cs_instance( + &pp.ck_primary, + &pp[circuit_index].r1cs_shape, + &l_u_primary, + ); + + // IVC proof of the secondary circuit + let l_w_secondary = w_secondary; + let l_u_secondary = u_secondary; + + // Initialize relaxed instance/witness pair for the secondary circuit proofs + let r_W_secondary = RelaxedR1CSWitness::::default(&pp.circuit_shape_secondary.r1cs_shape); + let r_U_secondary = + RelaxedR1CSInstance::default(&pp.ck_secondary, &pp.circuit_shape_secondary.r1cs_shape); + + // Outputs of the two circuits and next program counter thus far. + let zi_primary = zi_primary + .iter() + .map(|v| v.get_value().ok_or(NovaError::SynthesisError.into())) + .collect::::Scalar>, SuperNovaError>>()?; + let zi_primary_pc_next = zi_primary_pc_next + .expect("zi_primary_pc_next missing") + .get_value() + .ok_or::(NovaError::SynthesisError.into())?; + let zi_secondary = zi_secondary + .iter() + .map(|v| v.get_value().ok_or(NovaError::SynthesisError.into())) + .collect::::Scalar>, SuperNovaError>>()?; + + // handle the base case by initialize U_next in next round + let r_W_primary_initial_list = (0..num_augmented_circuits) + .map(|i| (i == circuit_index).then(|| r_W_primary.clone())) + .collect::>>>(); + + let r_U_primary_initial_list = (0..num_augmented_circuits) + .map(|i| (i == circuit_index).then(|| r_U_primary.clone())) + .collect::>>>(); + Ok(Self { + pp_digest: pp.digest(), + num_augmented_circuits, + i: 0_usize, // after base case, next iteration start from 1 + z0_primary: z0_primary.to_vec(), + zi_primary, + + proven_circuit_index: circuit_index, + program_counter: zi_primary_pc_next, + + r_W_primary: r_W_primary_initial_list, + r_U_primary: r_U_primary_initial_list, + z0_secondary: z0_secondary.to_vec(), + zi_secondary, + r_W_secondary, + r_U_secondary, + l_w_secondary, + l_u_secondary, + }) + } + + /// executing a step of the incremental computation + #[allow(clippy::too_many_arguments)] + pub fn prove_step, C2: StepCircuit>( + &mut self, + pp: &PublicParams, + c_primary: &C1, + c_secondary: &C2, + ) -> Result<(), SuperNovaError> { + // First step was already done in the constructor + if self.i == 0 { + self.i = 1; + return Ok(()); + } + + let circuit_index = c_primary.circuit_index(); + assert_eq!(self.program_counter, E1::Scalar::from(circuit_index as u64)); + + // fold the secondary circuit's instance + let (nifs_secondary, (r_U_secondary_folded, r_W_secondary_folded)) = NIFS::prove( + &pp.ck_secondary, + &pp.ro_consts_secondary, + &scalar_as_base::(self.pp_digest), + &pp.circuit_shape_secondary.r1cs_shape, + &self.r_U_secondary, + &self.r_W_secondary, + &self.l_u_secondary, + &self.l_w_secondary, + ) + .map_err(SuperNovaError::NovaError)?; + + // clone and updated running instance on respective circuit_index + let r_U_secondary_next = r_U_secondary_folded; + let r_W_secondary_next = r_W_secondary_folded; + + // Create single-entry accumulator list for the secondary circuit to hand to SuperNovaAugmentedCircuitInputs + let r_U_secondary = vec![Some(self.r_U_secondary.clone())]; + + let mut cs_primary = SatisfyingAssignment::::new(); + let T = + Commitment::::decompress(&nifs_secondary.comm_T).map_err(SuperNovaError::NovaError)?; + let inputs_primary: SuperNovaAugmentedCircuitInputs<'_, E2> = + SuperNovaAugmentedCircuitInputs::new( + scalar_as_base::(self.pp_digest), + E1::Scalar::from(self.i as u64), + &self.z0_primary, + Some(&self.zi_primary), + Some(&r_U_secondary), + Some(&self.l_u_secondary), + Some(&T), + Some(self.program_counter), + E1::Scalar::ZERO, + ); + + let circuit_primary: SuperNovaAugmentedCircuit<'_, E2, C1> = SuperNovaAugmentedCircuit::new( + &pp.augmented_circuit_params_primary, + Some(inputs_primary), + c_primary, + pp.ro_consts_circuit_primary.clone(), + self.num_augmented_circuits, + ); + + let (zi_primary_pc_next, zi_primary) = circuit_primary + .synthesize(&mut cs_primary) + .map_err(|_| SuperNovaError::NovaError(NovaError::SynthesisError))?; + if zi_primary.len() != pp[circuit_index].F_arity { + return Err(SuperNovaError::NovaError( + NovaError::InvalidInitialInputLength, + )); + } + + let (l_u_primary, l_w_primary) = cs_primary + .r1cs_instance_and_witness(&pp[circuit_index].r1cs_shape, &pp.ck_primary) + .map_err(SuperNovaError::NovaError)?; + + // Split into `if let`/`else` statement + // to avoid `returns a value referencing data owned by closure` error on `&RelaxedR1CSInstance::default` and `RelaxedR1CSWitness::default` + let (nifs_primary, (r_U_primary_folded, r_W_primary_folded)) = match ( + self.r_U_primary.get(circuit_index), + self.r_W_primary.get(circuit_index), + ) { + (Some(Some(r_U_primary)), Some(Some(r_W_primary))) => NIFS::prove( + &pp.ck_primary, + &pp.ro_consts_primary, + &self.pp_digest, + &pp[circuit_index].r1cs_shape, + r_U_primary, + r_W_primary, + &l_u_primary, + &l_w_primary, + ) + .map_err(SuperNovaError::NovaError)?, + _ => NIFS::prove( + &pp.ck_primary, + &pp.ro_consts_primary, + &self.pp_digest, + &pp[circuit_index].r1cs_shape, + &RelaxedR1CSInstance::default(&pp.ck_primary, &pp[circuit_index].r1cs_shape), + &RelaxedR1CSWitness::default(&pp[circuit_index].r1cs_shape), + &l_u_primary, + &l_w_primary, + ) + .map_err(SuperNovaError::NovaError)?, + }; + + let mut cs_secondary = SatisfyingAssignment::::new(); + let binding = + Commitment::::decompress(&nifs_primary.comm_T).map_err(SuperNovaError::NovaError)?; + let inputs_secondary: SuperNovaAugmentedCircuitInputs<'_, E1> = + SuperNovaAugmentedCircuitInputs::new( + self.pp_digest, + E2::Scalar::from(self.i as u64), + &self.z0_secondary, + Some(&self.zi_secondary), + Some(&self.r_U_primary), + Some(&l_u_primary), + Some(&binding), + None, // pc is always None for secondary circuit + E2::Scalar::from(circuit_index as u64), + ); + + let circuit_secondary: SuperNovaAugmentedCircuit<'_, E1, C2> = SuperNovaAugmentedCircuit::new( + &pp.augmented_circuit_params_secondary, + Some(inputs_secondary), + c_secondary, + pp.ro_consts_circuit_secondary.clone(), + self.num_augmented_circuits, + ); + let (_, zi_secondary) = circuit_secondary + .synthesize(&mut cs_secondary) + .map_err(|_| SuperNovaError::NovaError(NovaError::SynthesisError))?; + if zi_secondary.len() != pp.circuit_shape_secondary.F_arity { + return Err(SuperNovaError::NovaError( + NovaError::InvalidInitialInputLength, + )); + } + + let (l_u_secondary_next, l_w_secondary_next) = cs_secondary + .r1cs_instance_and_witness(&pp.circuit_shape_secondary.r1cs_shape, &pp.ck_secondary) + .map_err(|_| SuperNovaError::NovaError(NovaError::UnSat))?; + + // update the running instances and witnesses + let zi_primary = zi_primary + .iter() + .map(|v| { + v.get_value() + .ok_or(SuperNovaError::NovaError(NovaError::SynthesisError)) + }) + .collect::::Scalar>, SuperNovaError>>()?; + let zi_primary_pc_next = zi_primary_pc_next + .expect("zi_primary_pc_next missing") + .get_value() + .ok_or(SuperNovaError::NovaError(NovaError::SynthesisError))?; + let zi_secondary = zi_secondary + .iter() + .map(|v| { + v.get_value() + .ok_or(SuperNovaError::NovaError(NovaError::SynthesisError)) + }) + .collect::::Scalar>, SuperNovaError>>()?; + + if zi_primary.len() != pp[circuit_index].F_arity + || zi_secondary.len() != pp.circuit_shape_secondary.F_arity + { + return Err(SuperNovaError::NovaError( + NovaError::InvalidStepOutputLength, + )); + } + + // clone and updated running instance on respective circuit_index + self.r_U_primary[circuit_index] = Some(r_U_primary_folded); + self.r_W_primary[circuit_index] = Some(r_W_primary_folded); + self.r_W_secondary = r_W_secondary_next; + self.r_U_secondary = r_U_secondary_next; + self.l_w_secondary = l_w_secondary_next; + self.l_u_secondary = l_u_secondary_next; + self.i += 1; + self.zi_primary = zi_primary; + self.zi_secondary = zi_secondary; + self.proven_circuit_index = circuit_index; + self.program_counter = zi_primary_pc_next; + Ok(()) + } + + /// verify recursive snark + pub fn verify, C2: StepCircuit>( + &self, + pp: &PublicParams, + z0_primary: &[E1::Scalar], + z0_secondary: &[E2::Scalar], + ) -> Result<(Vec, Vec), SuperNovaError> { + // number of steps cannot be zero + if self.i == 0 { + debug!("must verify on valid RecursiveSNARK where i > 0"); + return Err(SuperNovaError::NovaError(NovaError::ProofVerifyError)); + } + + // Check lengths of r_primary + if self.r_U_primary.len() != self.num_augmented_circuits + || self.r_W_primary.len() != self.num_augmented_circuits + { + debug!("r_primary length mismatch"); + return Err(SuperNovaError::NovaError(NovaError::ProofVerifyError)); + } + + // Check that there are no missing instance/witness pairs + self + .r_U_primary + .iter() + .zip_eq(self.r_W_primary.iter()) + .enumerate() + .try_for_each(|(i, (u, w))| match (u, w) { + (Some(_), Some(_)) | (None, None) => Ok(()), + _ => { + debug!("r_primary[{:?}]: mismatched instance/witness pair", i); + Err(SuperNovaError::NovaError(NovaError::ProofVerifyError)) + } + })?; + + let circuit_index = self.proven_circuit_index; + + // check we have an instance/witness pair for the circuit_index + if self.r_U_primary[circuit_index].is_none() { + debug!( + "r_primary[{:?}]: instance/witness pair is missing", + circuit_index + ); + return Err(SuperNovaError::NovaError(NovaError::ProofVerifyError)); + } + + // check the (relaxed) R1CS instances public outputs. + { + for (i, r_U_primary_i) in self.r_U_primary.iter().enumerate() { + if let Some(u) = r_U_primary_i { + if u.X.len() != 2 { + debug!( + "r_U_primary[{:?}] got instance length {:?} != 2", + i, + u.X.len(), + ); + return Err(SuperNovaError::NovaError(NovaError::ProofVerifyError)); + } + } + } + + if self.l_u_secondary.X.len() != 2 { + debug!( + "l_U_secondary got instance length {:?} != 2", + self.l_u_secondary.X.len(), + ); + return Err(SuperNovaError::NovaError(NovaError::ProofVerifyError)); + } + + if self.r_U_secondary.X.len() != 2 { + debug!( + "r_U_secondary got instance length {:?} != 2", + self.r_U_secondary.X.len(), + ); + return Err(SuperNovaError::NovaError(NovaError::ProofVerifyError)); + } + } + + let hash_primary = { + let num_absorbs = num_ro_inputs( + self.num_augmented_circuits, + pp.augmented_circuit_params_primary.get_n_limbs(), + pp[circuit_index].F_arity, + true, // is_primary + ); + + let mut hasher = ::RO::new(pp.ro_consts_secondary.clone(), num_absorbs); + hasher.absorb(self.pp_digest); + hasher.absorb(E1::Scalar::from(self.i as u64)); + hasher.absorb(self.program_counter); + + for e in z0_primary { + hasher.absorb(*e); + } + for e in &self.zi_primary { + hasher.absorb(*e); + } + + self.r_U_secondary.absorb_in_ro(&mut hasher); + hasher.squeeze(NUM_HASH_BITS) + }; + + let hash_secondary = { + let num_absorbs = num_ro_inputs( + self.num_augmented_circuits, + pp.augmented_circuit_params_secondary.get_n_limbs(), + pp.circuit_shape_secondary.F_arity, + false, // is_primary + ); + let mut hasher = ::RO::new(pp.ro_consts_primary.clone(), num_absorbs); + hasher.absorb(scalar_as_base::(self.pp_digest)); + hasher.absorb(E2::Scalar::from(self.i as u64)); + + for e in z0_secondary { + hasher.absorb(*e); + } + for e in &self.zi_secondary { + hasher.absorb(*e); + } + + self.r_U_primary.iter().enumerate().for_each(|(i, U)| { + U.as_ref() + .unwrap_or(&RelaxedR1CSInstance::default( + &pp.ck_primary, + &pp[i].r1cs_shape, + )) + .absorb_in_ro(&mut hasher); + }); + hasher.squeeze(NUM_HASH_BITS) + }; + + if hash_primary != self.l_u_secondary.X[0] { + debug!( + "hash_primary {:?} not equal l_u_secondary.X[0] {:?}", + hash_primary, self.l_u_secondary.X[0] + ); + return Err(SuperNovaError::NovaError(NovaError::ProofVerifyError)); + } + if hash_secondary != scalar_as_base::(self.l_u_secondary.X[1]) { + debug!( + "hash_secondary {:?} not equal l_u_secondary.X[1] {:?}", + hash_secondary, self.l_u_secondary.X[1] + ); + return Err(SuperNovaError::NovaError(NovaError::ProofVerifyError)); + } + + // check the satisfiability of all instance/witness pairs + let (res_r_primary, (res_r_secondary, res_l_secondary)) = rayon::join( + || { + self + .r_U_primary + .par_iter() + .zip_eq(self.r_W_primary.par_iter()) + .enumerate() + .try_for_each(|(i, (u, w))| { + if let (Some(u), Some(w)) = (u, w) { + pp[i].r1cs_shape.is_sat_relaxed(&pp.ck_primary, u, w)? + } + Ok(()) + }) + }, + || { + rayon::join( + || { + pp.circuit_shape_secondary.r1cs_shape.is_sat_relaxed( + &pp.ck_secondary, + &self.r_U_secondary, + &self.r_W_secondary, + ) + }, + || { + pp.circuit_shape_secondary.r1cs_shape.is_sat( + &pp.ck_secondary, + &self.l_u_secondary, + &self.l_w_secondary, + ) + }, + ) + }, + ); + + res_r_primary.map_err(|err| match err { + NovaError::UnSatIndex(i) => SuperNovaError::UnSatIndex("r_primary", i), + e => SuperNovaError::NovaError(e), + })?; + res_r_secondary.map_err(|err| match err { + NovaError::UnSatIndex(i) => SuperNovaError::UnSatIndex("r_secondary", i), + e => SuperNovaError::NovaError(e), + })?; + res_l_secondary.map_err(|err| match err { + NovaError::UnSatIndex(i) => SuperNovaError::UnSatIndex("l_secondary", i), + e => SuperNovaError::NovaError(e), + })?; + + Ok((self.zi_primary.clone(), self.zi_secondary.clone())) + } +} + +/// SuperNova helper trait, for implementors that provide sets of sub-circuits to be proved via NIVC. `C1` must be a +/// type (likely an `Enum`) for which a potentially-distinct instance can be supplied for each `index` below +/// `self.num_circuits()`. +pub trait NonUniformCircuit +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + C1: StepCircuit, + C2: StepCircuit, +{ + /// Initial circuit index, defaults to zero. + fn initial_circuit_index(&self) -> usize { + 0 + } + + /// How many circuits are provided? + fn num_circuits(&self) -> usize; + + /// Return a new instance of the primary circuit at `index`. + fn primary_circuit(&self, circuit_index: usize) -> C1; + + /// Return a new instance of the secondary circuit. + fn secondary_circuit(&self) -> C2; +} + +/// Extension trait to simplify getting scalar form of initial circuit index. +pub trait InitialProgramCounter: NonUniformCircuit +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + C1: StepCircuit, + C2: StepCircuit, +{ + /// Initial program counter is the initial circuit index as a `Scalar`. + fn initial_program_counter(&self) -> E1::Scalar { + E1::Scalar::from(self.initial_circuit_index() as u64) + } +} + +impl> InitialProgramCounter + for T +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + C1: StepCircuit, + C2: StepCircuit, +{ +} + +/// Compute the circuit digest of a supernova [StepCircuit]. +/// +/// Note for callers: This function should be called with its performance characteristics in mind. +/// It will synthesize and digest the full `circuit` given. +pub fn circuit_digest< + E1: Engine::Scalar>, + E2: Engine::Scalar>, + C: StepCircuit, +>( + circuit: &C, + num_augmented_circuits: usize, +) -> E1::Scalar { + let augmented_circuit_params = + SuperNovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, true); + + // ro_consts_circuit are parameterized by E2 because the type alias uses E2::Base = E1::Scalar + let ro_consts_circuit: ROConstantsCircuit = ROConstantsCircuit::::default(); + + // Initialize ck for the primary + let augmented_circuit: SuperNovaAugmentedCircuit<'_, E2, C> = SuperNovaAugmentedCircuit::new( + &augmented_circuit_params, + None, + circuit, + ro_consts_circuit, + num_augmented_circuits, + ); + let mut cs: ShapeCS = ShapeCS::new(); + let _ = augmented_circuit.synthesize(&mut cs); + + let F_arity = circuit.arity(); + let circuit_params = CircuitShape::new(cs.r1cs_shape(), F_arity); + circuit_params.digest() +} + +/// Compute the number of absorbs for the random-oracle computing the circuit output +/// X = H(vk, i, pc, z0, zi, U) +fn num_ro_inputs(num_circuits: usize, num_limbs: usize, arity: usize, is_primary: bool) -> usize { + let num_circuits = if is_primary { 1 } else { num_circuits }; + + // [W(x,y,∞), E(x,y,∞), u] + [X0, X1] * #num_limb + let instance_size = 3 + 3 + 1 + 2 * num_limbs; + + 2 // params, i + + usize::from(is_primary) // optional program counter + + 2 * arity // z0, zi + + num_circuits * instance_size +} + +pub mod error; +pub mod snark; +pub(crate) mod utils; + +#[cfg(test)] +mod test; diff --git a/src/supernova/snark.rs b/src/supernova/snark.rs new file mode 100644 index 000000000..476ba4bfe --- /dev/null +++ b/src/supernova/snark.rs @@ -0,0 +1,749 @@ +//! This module defines a final compressing SNARK for supernova proofs + +use super::{error::SuperNovaError, PublicParams, RecursiveSNARK}; +use crate::{ + constants::NUM_HASH_BITS, + r1cs::{R1CSInstance, RelaxedR1CSWitness}, + traits::{ + circuit_supernova::StepCircuit, + snark::{BatchedRelaxedR1CSSNARKTrait, RelaxedR1CSSNARKTrait}, + AbsorbInROTrait, Engine, ROTrait, + }, +}; +use crate::{errors::NovaError, scalar_as_base, RelaxedR1CSInstance, NIFS}; +use ff::PrimeField; +use serde::{Deserialize, Serialize}; +use std::marker::PhantomData; + +/// A type that holds the prover key for `CompressedSNARK` +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct ProverKey +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + C1: StepCircuit, + C2: StepCircuit, + S1: BatchedRelaxedR1CSSNARKTrait, + S2: RelaxedR1CSSNARKTrait, +{ + pk_primary: S1::ProverKey, + pk_secondary: S2::ProverKey, + _p: PhantomData<(C1, C2)>, +} + +/// A type that holds the verifier key for `CompressedSNARK` +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct VerifierKey +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + C1: StepCircuit, + C2: StepCircuit, + S1: BatchedRelaxedR1CSSNARKTrait, + S2: RelaxedR1CSSNARKTrait, +{ + vk_primary: S1::VerifierKey, + vk_secondary: S2::VerifierKey, + _p: PhantomData<(C1, C2)>, +} + +/// A SNARK that proves the knowledge of a valid `RecursiveSNARK` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct CompressedSNARK +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + C1: StepCircuit, + C2: StepCircuit, + S1: BatchedRelaxedR1CSSNARKTrait, + S2: RelaxedR1CSSNARKTrait, +{ + r_U_primary: Vec>, + r_W_snark_primary: S1, + + r_U_secondary: RelaxedR1CSInstance, + l_u_secondary: R1CSInstance, + nifs_secondary: NIFS, + f_W_snark_secondary: S2, + + num_steps: usize, + program_counter: E1::Scalar, + + zn_primary: Vec, + zn_secondary: Vec, + _p: PhantomData<(E1, E2, C1, C2, S1, S2)>, +} + +impl CompressedSNARK +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + C1: StepCircuit, + C2: StepCircuit, + S1: BatchedRelaxedR1CSSNARKTrait, + S2: RelaxedR1CSSNARKTrait, +{ + /// Creates prover and verifier keys for `CompressedSNARK` + pub fn setup( + pp: &PublicParams, + ) -> Result< + ( + ProverKey, + VerifierKey, + ), + SuperNovaError, + > { + let (pk_primary, vk_primary) = S1::setup(&pp.ck_primary, pp.primary_r1cs_shapes())?; + + let (pk_secondary, vk_secondary) = + S2::setup(&pp.ck_secondary, &pp.circuit_shape_secondary.r1cs_shape)?; + + let prover_key = ProverKey { + pk_primary, + pk_secondary, + _p: PhantomData, + }; + let verifier_key = VerifierKey { + vk_primary, + vk_secondary, + _p: PhantomData, + }; + + Ok((prover_key, verifier_key)) + } + + /// Create a new `CompressedSNARK` + pub fn prove( + pp: &PublicParams, + pk: &ProverKey, + recursive_snark: &RecursiveSNARK, + ) -> Result { + // fold the secondary circuit's instance + let res_secondary = NIFS::prove( + &pp.ck_secondary, + &pp.ro_consts_secondary, + &scalar_as_base::(pp.digest()), + &pp.circuit_shape_secondary.r1cs_shape, + &recursive_snark.r_U_secondary, + &recursive_snark.r_W_secondary, + &recursive_snark.l_u_secondary, + &recursive_snark.l_w_secondary, + ); + + let (nifs_secondary, (f_U_secondary, f_W_secondary)) = res_secondary?; + + // Prepare the list of primary Relaxed R1CS instances (a default instance is provided for + // uninitialized circuits) + let r_U_primary = recursive_snark + .r_U_primary + .iter() + .enumerate() + .map(|(idx, r_U)| { + r_U + .clone() + .unwrap_or_else(|| RelaxedR1CSInstance::default(&pp.ck_primary, &pp[idx].r1cs_shape)) + }) + .collect::>(); + + // Prepare the list of primary relaxed R1CS witnesses (a default witness is provided for + // uninitialized circuits) + let r_W_primary: Vec> = recursive_snark + .r_W_primary + .iter() + .enumerate() + .map(|(idx, r_W)| { + r_W + .clone() + .unwrap_or_else(|| RelaxedR1CSWitness::default(&pp[idx].r1cs_shape)) + }) + .collect::>(); + + // Generate a primary SNARK proof for the list of primary circuits + let r_W_snark_primary = S1::prove( + &pp.ck_primary, + &pk.pk_primary, + pp.primary_r1cs_shapes(), + &r_U_primary, + &r_W_primary, + )?; + + // Generate a secondary SNARK proof for the secondary circuit + let f_W_snark_secondary = S2::prove( + &pp.ck_secondary, + &pk.pk_secondary, + &pp.circuit_shape_secondary.r1cs_shape, + &f_U_secondary, + &f_W_secondary, + )?; + + let compressed_snark = CompressedSNARK { + r_U_primary, + r_W_snark_primary, + + r_U_secondary: recursive_snark.r_U_secondary.clone(), + l_u_secondary: recursive_snark.l_u_secondary.clone(), + nifs_secondary, + f_W_snark_secondary, + + num_steps: recursive_snark.i, + program_counter: recursive_snark.program_counter, + + zn_primary: recursive_snark.zi_primary.clone(), + zn_secondary: recursive_snark.zi_secondary.clone(), + + _p: PhantomData, + }; + + Ok(compressed_snark) + } + + /// Verify the correctness of the `CompressedSNARK` + pub fn verify( + &self, + pp: &PublicParams, + vk: &VerifierKey, + z0_primary: &[E1::Scalar], + z0_secondary: &[E2::Scalar], + ) -> Result<(Vec, Vec), SuperNovaError> { + let last_circuit_idx = field_as_usize(self.program_counter); + + let num_field_primary_ro = 3 // params_next, i_new, program_counter_new + + 2 * pp[last_circuit_idx].F_arity // zo, z1 + + (7 + 2 * pp.augmented_circuit_params_primary.get_n_limbs()); // # 1 * (7 + [X0, X1]*#num_limb) + + // secondary circuit + // NOTE: This count ensure the number of witnesses sent by the prover must equal the number of + // NIVC circuits + let num_field_secondary_ro = 2 // params_next, i_new + + 2 * pp.circuit_shape_secondary.F_arity // zo, z1 + + pp.circuit_shapes.len() * (7 + 2 * pp.augmented_circuit_params_primary.get_n_limbs()); // #num_augment + + // Compute the primary and secondary hashes given the digest, program counter, instances, and + // witnesses provided by the prover + let (hash_primary, hash_secondary) = { + let mut hasher = + ::RO::new(pp.ro_consts_secondary.clone(), num_field_primary_ro); + + hasher.absorb(pp.digest()); + hasher.absorb(E1::Scalar::from(self.num_steps as u64)); + hasher.absorb(self.program_counter); + + for e in z0_primary { + hasher.absorb(*e); + } + + for e in &self.zn_primary { + hasher.absorb(*e); + } + + self.r_U_secondary.absorb_in_ro(&mut hasher); + + let mut hasher2 = + ::RO::new(pp.ro_consts_primary.clone(), num_field_secondary_ro); + + hasher2.absorb(scalar_as_base::(pp.digest())); + hasher2.absorb(E2::Scalar::from(self.num_steps as u64)); + + for e in z0_secondary { + hasher2.absorb(*e); + } + + for e in &self.zn_secondary { + hasher2.absorb(*e); + } + + self.r_U_primary.iter().for_each(|U| { + U.absorb_in_ro(&mut hasher2); + }); + + ( + hasher.squeeze(NUM_HASH_BITS), + hasher2.squeeze(NUM_HASH_BITS), + ) + }; + + // Compare the computed hashes with the public IO of the last invocation of `prove_step` + if hash_primary != self.l_u_secondary.X[0] { + return Err(NovaError::ProofVerifyError.into()); + } + + if hash_secondary != scalar_as_base::(self.l_u_secondary.X[1]) { + return Err(NovaError::ProofVerifyError.into()); + } + + // Verify the primary SNARK + let res_primary = self + .r_W_snark_primary + .verify(&vk.vk_primary, &self.r_U_primary); + + // Fold the secondary circuit's instance + let f_U_secondary = self.nifs_secondary.verify( + &pp.ro_consts_secondary, + &scalar_as_base::(pp.digest()), + &self.r_U_secondary, + &self.l_u_secondary, + )?; + + // Verify the secondary SNARK + let res_secondary = self + .f_W_snark_secondary + .verify(&vk.vk_secondary, &f_U_secondary); + + res_primary?; + + res_secondary?; + + Ok((self.zn_primary.clone(), self.zn_secondary.clone())) + } +} + +fn field_as_usize(x: F) -> usize { + u32::from_le_bytes(x.to_repr().as_ref()[0..4].try_into().unwrap()) as usize +} + +#[cfg(test)] +mod test { + use super::*; + use crate::{ + provider::{ + ipa_pc, Bn256Engine, GrumpkinEngine, PallasEngine, Secp256k1Engine, Secq256k1Engine, + VestaEngine, + }, + spartan::{batched, batched_ppsnark, snark::RelaxedR1CSSNARK}, + supernova::NonUniformCircuit, + traits::circuit_supernova::TrivialSecondaryCircuit, + }; + + use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; + use ff::Field; + + type EE = ipa_pc::EvaluationEngine; + type S1 = batched::BatchedRelaxedR1CSSNARK>; + type S1PP = batched_ppsnark::BatchedRelaxedR1CSSNARK>; + type S2 = RelaxedR1CSSNARK>; + + #[derive(Clone)] + struct SquareCircuit { + _p: PhantomData, + } + + impl StepCircuit for SquareCircuit { + fn arity(&self) -> usize { + 1 + } + + fn circuit_index(&self) -> usize { + 0 + } + + fn synthesize>( + &self, + cs: &mut CS, + _pc: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result< + ( + Option>, + Vec>, + ), + SynthesisError, + > { + let z_i = &z[0]; + + let z_next = z_i.square(cs.namespace(|| "z_i^2"))?; + + let next_pc = AllocatedNum::alloc(cs.namespace(|| "next_pc"), || Ok(E::Scalar::from(1u64)))?; + + cs.enforce( + || "next_pc = 1", + |lc| lc + CS::one(), + |lc| lc + next_pc.get_variable(), + |lc| lc + CS::one(), + ); + + Ok((Some(next_pc), vec![z_next])) + } + } + + #[derive(Clone)] + struct CubeCircuit { + _p: PhantomData, + } + + impl StepCircuit for CubeCircuit { + fn arity(&self) -> usize { + 1 + } + + fn circuit_index(&self) -> usize { + 1 + } + + fn synthesize>( + &self, + cs: &mut CS, + _pc: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result< + ( + Option>, + Vec>, + ), + SynthesisError, + > { + let z_i = &z[0]; + + let z_sq = z_i.square(cs.namespace(|| "z_i^2"))?; + let z_cu = z_sq.mul(cs.namespace(|| "z_i^3"), z_i)?; + + let next_pc = AllocatedNum::alloc(cs.namespace(|| "next_pc"), || Ok(E::Scalar::from(0u64)))?; + + cs.enforce( + || "next_pc = 0", + |lc| lc + CS::one(), + |lc| lc + next_pc.get_variable(), + |lc| lc, + ); + + Ok((Some(next_pc), vec![z_cu])) + } + } + + #[derive(Clone)] + enum TestCircuit { + Square(SquareCircuit), + Cube(CubeCircuit), + } + + impl TestCircuit { + fn new(num_steps: usize) -> Vec { + let mut circuits = Vec::new(); + + for idx in 0..num_steps { + if idx % 2 == 0 { + circuits.push(Self::Square(SquareCircuit { _p: PhantomData })) + } else { + circuits.push(Self::Cube(CubeCircuit { _p: PhantomData })) + } + } + + circuits + } + } + + impl StepCircuit for TestCircuit { + fn arity(&self) -> usize { + 1 + } + + fn circuit_index(&self) -> usize { + match self { + TestCircuit::Square(c) => c.circuit_index(), + TestCircuit::Cube(c) => c.circuit_index(), + } + } + + fn synthesize>( + &self, + cs: &mut CS, + pc: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result< + ( + Option>, + Vec>, + ), + SynthesisError, + > { + match self { + TestCircuit::Square(c) => c.synthesize(cs, pc, z), + TestCircuit::Cube(c) => c.synthesize(cs, pc, z), + } + } + } + + impl NonUniformCircuit, TrivialSecondaryCircuit> + for TestCircuit + where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + { + fn num_circuits(&self) -> usize { + 2 + } + + fn primary_circuit(&self, circuit_index: usize) -> TestCircuit { + match circuit_index { + 0 => Self::Square(SquareCircuit { _p: PhantomData }), + 1 => Self::Cube(CubeCircuit { _p: PhantomData }), + _ => panic!("Invalid circuit index"), + } + } + + fn secondary_circuit(&self) -> TrivialSecondaryCircuit { + Default::default() + } + } + + fn test_nivc_trivial_with_compression_with() + where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + S1: BatchedRelaxedR1CSSNARKTrait, + S2: RelaxedR1CSSNARKTrait, + { + const NUM_STEPS: usize = 6; + + let secondary_circuit = TrivialSecondaryCircuit::default(); + let test_circuits = TestCircuit::new(NUM_STEPS); + + let pp = PublicParams::setup(&test_circuits[0], &*S1::ck_floor(), &*S2::ck_floor()); + + let z0_primary = vec![E1::Scalar::from(17u64)]; + let z0_secondary = vec![::Scalar::ZERO]; + + let mut recursive_snark = RecursiveSNARK::new( + &pp, + &test_circuits[0], + &test_circuits[0], + &secondary_circuit, + &z0_primary, + &z0_secondary, + ) + .unwrap(); + + for circuit in test_circuits.iter().take(NUM_STEPS) { + let prove_res = recursive_snark.prove_step(&pp, circuit, &secondary_circuit); + + let verify_res = recursive_snark.verify(&pp, &z0_primary, &z0_secondary); + + assert!(prove_res.is_ok()); + assert!(verify_res.is_ok()); + } + + let (prover_key, verifier_key) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&pp).unwrap(); + + let compressed_prove_res = CompressedSNARK::prove(&pp, &prover_key, &recursive_snark); + + assert!(compressed_prove_res.is_ok()); + + let compressed_snark = compressed_prove_res.unwrap(); + + let compressed_verify_res = + compressed_snark.verify(&pp, &verifier_key, &z0_primary, &z0_secondary); + + assert!(compressed_verify_res.is_ok()); + } + + #[test] + fn test_nivc_trivial_with_compression() { + // ppSNARK + test_nivc_trivial_with_compression_with::, S2<_>>(); + test_nivc_trivial_with_compression_with::, S2<_>>(); + test_nivc_trivial_with_compression_with::, S2<_>>(); + // classic SNARK + test_nivc_trivial_with_compression_with::, S2<_>>(); + test_nivc_trivial_with_compression_with::, S2<_>>(); + test_nivc_trivial_with_compression_with::, S2<_>>(); + } + + #[derive(Clone)] + struct BigPowerCircuit { + _p: PhantomData, + } + + impl StepCircuit for BigPowerCircuit { + fn arity(&self) -> usize { + 1 + } + + fn circuit_index(&self) -> usize { + 1 + } + + fn synthesize>( + &self, + cs: &mut CS, + _pc: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result< + ( + Option>, + Vec>, + ), + SynthesisError, + > { + let mut x = z[0].clone(); + let mut y = x.clone(); + for i in 0..10_000 { + y = x.square(cs.namespace(|| format!("x_sq_{i}")))?; + x = y.clone(); + } + + let next_pc = AllocatedNum::alloc(cs.namespace(|| "next_pc"), || Ok(E::Scalar::from(0u64)))?; + + cs.enforce( + || "next_pc = 0", + |lc| lc + CS::one(), + |lc| lc + next_pc.get_variable(), + |lc| lc, + ); + + Ok((Some(next_pc), vec![y])) + } + } + + #[derive(Clone)] + enum BigTestCircuit { + Square(SquareCircuit), + BigPower(BigPowerCircuit), + } + + impl BigTestCircuit { + fn new(num_steps: usize) -> Vec { + let mut circuits = Vec::new(); + + for idx in 0..num_steps { + if idx % 2 == 0 { + circuits.push(Self::Square(SquareCircuit { _p: PhantomData })) + } else { + circuits.push(Self::BigPower(BigPowerCircuit { _p: PhantomData })) + } + } + + circuits + } + } + + impl StepCircuit for BigTestCircuit { + fn arity(&self) -> usize { + 1 + } + + fn circuit_index(&self) -> usize { + match self { + BigTestCircuit::Square(c) => c.circuit_index(), + BigTestCircuit::BigPower(c) => c.circuit_index(), + } + } + + fn synthesize>( + &self, + cs: &mut CS, + pc: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result< + ( + Option>, + Vec>, + ), + SynthesisError, + > { + match self { + BigTestCircuit::Square(c) => c.synthesize(cs, pc, z), + BigTestCircuit::BigPower(c) => c.synthesize(cs, pc, z), + } + } + } + + impl NonUniformCircuit, TrivialSecondaryCircuit> + for BigTestCircuit + where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + { + fn num_circuits(&self) -> usize { + 2 + } + + fn primary_circuit(&self, circuit_index: usize) -> BigTestCircuit { + match circuit_index { + 0 => Self::Square(SquareCircuit { _p: PhantomData }), + 1 => Self::BigPower(BigPowerCircuit { _p: PhantomData }), + _ => panic!("Invalid circuit index"), + } + } + + fn secondary_circuit(&self) -> TrivialSecondaryCircuit { + Default::default() + } + } + + fn test_compression_with_circuit_size_difference_with() + where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + S1: BatchedRelaxedR1CSSNARKTrait, + S2: RelaxedR1CSSNARKTrait, + { + const NUM_STEPS: usize = 4; + + let secondary_circuit = TrivialSecondaryCircuit::default(); + let test_circuits = BigTestCircuit::new(NUM_STEPS); + + let pp = PublicParams::setup(&test_circuits[0], &*S1::ck_floor(), &*S2::ck_floor()); + + let z0_primary = vec![E1::Scalar::from(17u64)]; + let z0_secondary = vec![::Scalar::ZERO]; + + let mut recursive_snark = RecursiveSNARK::new( + &pp, + &test_circuits[0], + &test_circuits[0], + &secondary_circuit, + &z0_primary, + &z0_secondary, + ) + .unwrap(); + + for circuit in test_circuits.iter().take(NUM_STEPS) { + let prove_res = recursive_snark.prove_step(&pp, circuit, &secondary_circuit); + + let verify_res = recursive_snark.verify(&pp, &z0_primary, &z0_secondary); + + assert!(prove_res.is_ok()); + assert!(verify_res.is_ok()); + } + + let (prover_key, verifier_key) = CompressedSNARK::<_, _, _, _, S1, S2>::setup(&pp).unwrap(); + + let compressed_prove_res = CompressedSNARK::prove(&pp, &prover_key, &recursive_snark); + + assert!(compressed_prove_res.is_ok()); + + let compressed_snark = compressed_prove_res.unwrap(); + + let compressed_verify_res = + compressed_snark.verify(&pp, &verifier_key, &z0_primary, &z0_secondary); + + assert!(compressed_verify_res.is_ok()); + } + + #[test] + fn test_compression_with_circuit_size_difference() { + // ppSNARK + test_compression_with_circuit_size_difference_with::, S2<_>>( + ); + test_compression_with_circuit_size_difference_with::, S2<_>>( + ); + test_compression_with_circuit_size_difference_with::< + Secp256k1Engine, + Secq256k1Engine, + S1PP<_>, + S2<_>, + >(); + // classic SNARK + test_compression_with_circuit_size_difference_with::, S2<_>>(); + test_compression_with_circuit_size_difference_with::, S2<_>>( + ); + test_compression_with_circuit_size_difference_with::< + Secp256k1Engine, + Secq256k1Engine, + S1<_>, + S2<_>, + >(); + } +} diff --git a/src/supernova/test.rs b/src/supernova/test.rs new file mode 100644 index 000000000..e69eeb972 --- /dev/null +++ b/src/supernova/test.rs @@ -0,0 +1,904 @@ +use crate::gadgets::utils::alloc_zero; +use crate::provider::poseidon::PoseidonConstantsCircuit; +use crate::provider::Bn256Engine; +use crate::provider::GrumpkinEngine; +use crate::provider::PallasEngine; +use crate::provider::Secp256k1Engine; +use crate::provider::Secq256k1Engine; +use crate::provider::VestaEngine; +use crate::traits::circuit_supernova::{ + EnforcingStepCircuit, StepCircuit, TrivialSecondaryCircuit, TrivialTestCircuit, +}; +use crate::traits::snark::default_ck_hint; +use crate::{bellpepper::test_shape_cs::TestShapeCS, gadgets::utils::alloc_one}; +use bellpepper_core::num::AllocatedNum; +use bellpepper_core::{ConstraintSystem, SynthesisError}; +use core::marker::PhantomData; +use ff::Field; +use ff::PrimeField; +use std::fmt::Write; +use tap::TapOptional; + +use super::{utils::get_selector_vec_from_index, *}; + +#[derive(Clone, Debug, Default)] +struct CubicCircuit { + _p: PhantomData, + circuit_index: usize, + rom_size: usize, +} + +impl CubicCircuit +where + F: PrimeField, +{ + fn new(circuit_index: usize, rom_size: usize) -> Self { + CubicCircuit { + circuit_index, + rom_size, + _p: PhantomData, + } + } +} + +fn next_rom_index_and_pc>( + cs: &mut CS, + rom_index: &AllocatedNum, + allocated_rom: &[AllocatedNum], + pc: &AllocatedNum, +) -> Result<(AllocatedNum, AllocatedNum), SynthesisError> { + // Compute a selector for the current rom_index in allocated_rom + let current_rom_selector = get_selector_vec_from_index( + cs.namespace(|| "rom selector"), + rom_index, + allocated_rom.len(), + )?; + + // Enforce that allocated_rom[rom_index] = pc + for (rom, bit) in allocated_rom.iter().zip_eq(current_rom_selector.iter()) { + // if bit = 1, then rom = pc + // bit * (rom - pc) = 0 + cs.enforce( + || "enforce bit = 1 => rom = pc", + |lc| lc + &bit.lc(CS::one(), F::ONE), + |lc| lc + rom.get_variable() - pc.get_variable(), + |lc| lc, + ); + } + + // Get the index of the current rom, or the index of the invalid rom if no match + let current_rom_index = current_rom_selector + .iter() + .position(|bit| bit.get_value().is_some_and(|v| v)) + .unwrap_or_default(); + let next_rom_index = current_rom_index + 1; + + let rom_index_next = AllocatedNum::alloc_infallible(cs.namespace(|| "next rom index"), || { + F::from(next_rom_index as u64) + }); + cs.enforce( + || " rom_index + 1 - next_rom_index_num = 0", + |lc| lc, + |lc| lc, + |lc| lc + rom_index.get_variable() + CS::one() - rom_index_next.get_variable(), + ); + + // Allocate the next pc without checking. + // The next iteration will check whether the next pc is valid. + let pc_next = AllocatedNum::alloc_infallible(cs.namespace(|| "next pc"), || { + allocated_rom + .get(next_rom_index) + .and_then(|v| v.get_value()) + .unwrap_or(-F::ONE) + }); + + Ok((rom_index_next, pc_next)) +} + +impl StepCircuit for CubicCircuit +where + F: PrimeField, +{ + fn arity(&self) -> usize { + 2 + self.rom_size // value + rom_pc + rom[].len() + } + + fn circuit_index(&self) -> usize { + self.circuit_index + } + + fn synthesize>( + &self, + cs: &mut CS, + pc: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result<(Option>, Vec>), SynthesisError> { + let rom_index = &z[1]; + let allocated_rom = &z[2..]; + + let (rom_index_next, pc_next) = next_rom_index_and_pc( + &mut cs.namespace(|| "next and rom_index and pc"), + rom_index, + allocated_rom, + pc.ok_or(SynthesisError::AssignmentMissing)?, + )?; + + // Consider a cubic equation: `x^3 + x + 5 = y`, where `x` and `y` are respectively the input and output. + let x = &z[0]; + let x_sq = x.square(cs.namespace(|| "x_sq"))?; + let x_cu = x_sq.mul(cs.namespace(|| "x_cu"), x)?; + let y = AllocatedNum::alloc(cs.namespace(|| "y"), || { + Ok(x_cu.get_value().unwrap() + x.get_value().unwrap() + F::from(5u64)) + })?; + + cs.enforce( + || "y = x^3 + x + 5", + |lc| { + lc + x_cu.get_variable() + + x.get_variable() + + CS::one() + + CS::one() + + CS::one() + + CS::one() + + CS::one() + }, + |lc| lc + CS::one(), + |lc| lc + y.get_variable(), + ); + + let mut z_next = vec![y]; + z_next.push(rom_index_next); + z_next.extend(z[2..].iter().cloned()); + Ok((Some(pc_next), z_next)) + } +} + +#[derive(Clone, Debug, Default)] +struct SquareCircuit { + _p: PhantomData, + circuit_index: usize, + rom_size: usize, +} + +impl SquareCircuit +where + F: PrimeField, +{ + fn new(circuit_index: usize, rom_size: usize) -> Self { + SquareCircuit { + circuit_index, + rom_size, + _p: PhantomData, + } + } +} + +impl StepCircuit for SquareCircuit +where + F: PrimeField, +{ + fn arity(&self) -> usize { + 2 + self.rom_size // value + rom_pc + rom[].len() + } + + fn circuit_index(&self) -> usize { + self.circuit_index + } + + fn synthesize>( + &self, + cs: &mut CS, + pc: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result<(Option>, Vec>), SynthesisError> { + let rom_index = &z[1]; + let allocated_rom = &z[2..]; + + let (rom_index_next, pc_next) = next_rom_index_and_pc( + &mut cs.namespace(|| "next and rom_index and pc"), + rom_index, + allocated_rom, + pc.ok_or(SynthesisError::AssignmentMissing)?, + )?; + + // Consider an equation: `x^2 + x + 5 = y`, where `x` and `y` are respectively the input and output. + let x = &z[0]; + let x_sq = x.square(cs.namespace(|| "x_sq"))?; + let y = AllocatedNum::alloc(cs.namespace(|| "y"), || { + Ok(x_sq.get_value().unwrap() + x.get_value().unwrap() + F::from(5u64)) + })?; + + cs.enforce( + || "y = x^2 + x + 5", + |lc| { + lc + x_sq.get_variable() + + x.get_variable() + + CS::one() + + CS::one() + + CS::one() + + CS::one() + + CS::one() + }, + |lc| lc + CS::one(), + |lc| lc + y.get_variable(), + ); + + let mut z_next = vec![y]; + z_next.push(rom_index_next); + z_next.extend(z[2..].iter().cloned()); + Ok((Some(pc_next), z_next)) + } +} + +fn print_constraints_name_on_error_index( + err: &SuperNovaError, + pp: &PublicParams, + c_primary: &C1, + c_secondary: &C2, + num_augmented_circuits: usize, +) where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + C1: EnforcingStepCircuit, + C2: EnforcingStepCircuit, +{ + match err { + SuperNovaError::UnSatIndex(msg, index) if *msg == "r_primary" => { + let circuit_primary: SuperNovaAugmentedCircuit<'_, E2, C1> = SuperNovaAugmentedCircuit::new( + &pp.augmented_circuit_params_primary, + None, + c_primary, + pp.ro_consts_circuit_primary.clone(), + num_augmented_circuits, + ); + let mut cs: TestShapeCS = TestShapeCS::new(); + let _ = circuit_primary.synthesize(&mut cs); + cs.constraints + .get(*index) + .tap_some(|constraint| debug!("{msg} failed at constraint {}", constraint.3)); + } + SuperNovaError::UnSatIndex(msg, index) if *msg == "r_secondary" || *msg == "l_secondary" => { + let circuit_secondary: SuperNovaAugmentedCircuit<'_, E1, C2> = SuperNovaAugmentedCircuit::new( + &pp.augmented_circuit_params_secondary, + None, + c_secondary, + pp.ro_consts_circuit_secondary.clone(), + num_augmented_circuits, + ); + let mut cs: TestShapeCS = TestShapeCS::new(); + let _ = circuit_secondary.synthesize(&mut cs); + cs.constraints + .get(*index) + .tap_some(|constraint| debug!("{msg} failed at constraint {}", constraint.3)); + } + _ => (), + } +} + +const OPCODE_0: usize = 0; +const OPCODE_1: usize = 1; + +struct TestROM +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + S: EnforcingStepCircuit + Default, +{ + rom: Vec, + _p: PhantomData<(E1, E2, S)>, +} + +#[derive(Debug, Clone)] +enum TestROMCircuit { + Cubic(CubicCircuit), + Square(SquareCircuit), +} + +impl StepCircuit for TestROMCircuit { + fn arity(&self) -> usize { + match self { + Self::Cubic(x) => x.arity(), + Self::Square(x) => x.arity(), + } + } + + fn circuit_index(&self) -> usize { + match self { + Self::Cubic(x) => x.circuit_index(), + Self::Square(x) => x.circuit_index(), + } + } + + fn synthesize>( + &self, + cs: &mut CS, + pc: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result<(Option>, Vec>), SynthesisError> { + match self { + Self::Cubic(x) => x.synthesize(cs, pc, z), + Self::Square(x) => x.synthesize(cs, pc, z), + } + } +} + +impl + NonUniformCircuit, TrivialSecondaryCircuit> + for TestROM> +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, +{ + fn num_circuits(&self) -> usize { + 2 + } + + fn primary_circuit(&self, circuit_index: usize) -> TestROMCircuit { + match circuit_index { + 0 => TestROMCircuit::Cubic(CubicCircuit::new(circuit_index, self.rom.len())), + 1 => TestROMCircuit::Square(SquareCircuit::new(circuit_index, self.rom.len())), + _ => panic!("unsupported primary circuit index"), + } + } + + fn secondary_circuit(&self) -> TrivialSecondaryCircuit { + Default::default() + } + + fn initial_circuit_index(&self) -> usize { + self.rom[0] + } +} + +impl TestROM +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + S: EnforcingStepCircuit + Default, +{ + fn new(rom: Vec) -> Self { + Self { + rom, + _p: Default::default(), + } + } +} + +fn test_trivial_nivc_with() +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, +{ + // Here demo a simple RAM machine + // - with 2 argumented circuit + // - each argumented circuit contains primary and secondary circuit + // - a memory commmitment via a public IO `rom` (like a program) to constraint the sequence execution + + // This test also ready to add more argumented circuit and ROM can be arbitrary length + + // ROM is for constraints the sequence of execution order for opcode + + // TODO: replace with memory commitment along with suggestion from Supernova 4.4 optimisations + + // This is mostly done with the existing Nova code. With additions of U_i[] and program_counter checks + // in the augmented circuit. + + let rom = vec![ + OPCODE_1, OPCODE_1, OPCODE_0, OPCODE_0, OPCODE_1, OPCODE_1, OPCODE_0, OPCODE_0, OPCODE_1, + OPCODE_1, + ]; // Rom can be arbitrary length. + + let test_rom = TestROM::>::new(rom); + + let pp = PublicParams::setup(&test_rom, &*default_ck_hint(), &*default_ck_hint()); + + // extend z0_primary/secondary with rom content + let mut z0_primary = vec![::Scalar::ONE]; + z0_primary.push(::Scalar::ZERO); // rom_index = 0 + z0_primary.extend( + test_rom + .rom + .iter() + .map(|opcode| ::Scalar::from(*opcode as u64)), + ); + let z0_secondary = vec![::Scalar::ONE]; + + let mut recursive_snark_option: Option> = None; + + for &op_code in test_rom.rom.iter() { + let circuit_primary = test_rom.primary_circuit(op_code); + let circuit_secondary = test_rom.secondary_circuit(); + + let mut recursive_snark = recursive_snark_option.unwrap_or_else(|| { + RecursiveSNARK::new( + &pp, + &test_rom, + &circuit_primary, + &circuit_secondary, + &z0_primary, + &z0_secondary, + ) + .unwrap() + }); + + recursive_snark + .prove_step(&pp, &circuit_primary, &circuit_secondary) + .unwrap(); + recursive_snark + .verify(&pp, &z0_primary, &z0_secondary) + .map_err(|err| { + print_constraints_name_on_error_index( + &err, + &pp, + &circuit_primary, + &circuit_secondary, + test_rom.num_circuits(), + ) + }) + .unwrap(); + + recursive_snark_option = Some(recursive_snark) + } + + assert!(recursive_snark_option.is_some()); + + // Now you can handle the Result using if let + let RecursiveSNARK { + zi_primary, + zi_secondary, + program_counter, + .. + } = &recursive_snark_option.unwrap(); + + println!("zi_primary: {:?}", zi_primary); + println!("zi_secondary: {:?}", zi_secondary); + println!("final program_counter: {:?}", program_counter); + + // The final program counter should be -1 + assert_eq!(*program_counter, -::Scalar::ONE); +} + +#[test] +fn test_trivial_nivc() { + // Expirementing with selecting the running claims for nifs + test_trivial_nivc_with::(); +} + +// In the following we use 1 to refer to the primary, and 2 to refer to the secondary circuit +fn test_recursive_circuit_with( + primary_params: &SuperNovaAugmentedCircuitParams, + secondary_params: &SuperNovaAugmentedCircuitParams, + ro_consts1: ROConstantsCircuit, + ro_consts2: ROConstantsCircuit, + num_constraints_primary: usize, + num_constraints_secondary: usize, +) where + E1: Engine::Scalar>, + E2: Engine::Scalar>, +{ + // Initialize the shape and ck for the primary + let step_circuit1 = TrivialTestCircuit::default(); + let arity1 = step_circuit1.arity(); + let circuit1: SuperNovaAugmentedCircuit<'_, E2, TrivialTestCircuit<::Base>> = + SuperNovaAugmentedCircuit::new(primary_params, None, &step_circuit1, ro_consts1.clone(), 2); + let mut cs: ShapeCS = ShapeCS::new(); + if let Err(e) = circuit1.synthesize(&mut cs) { + panic!("{}", e) + } + let (shape1, ck1) = cs.r1cs_shape_and_key(&*default_ck_hint()); + assert_eq!(cs.num_constraints(), num_constraints_primary); + + // Initialize the shape and ck for the secondary + let step_circuit2 = TrivialSecondaryCircuit::default(); + let arity2 = step_circuit2.arity(); + let circuit2: SuperNovaAugmentedCircuit<'_, E1, TrivialSecondaryCircuit<::Base>> = + SuperNovaAugmentedCircuit::new( + secondary_params, + None, + &step_circuit2, + ro_consts2.clone(), + 2, + ); + let mut cs: ShapeCS = ShapeCS::new(); + if let Err(e) = circuit2.synthesize(&mut cs) { + panic!("{}", e) + } + let (shape2, ck2) = cs.r1cs_shape_and_key(&*default_ck_hint()); + assert_eq!(cs.num_constraints(), num_constraints_secondary); + + // Execute the base case for the primary + let zero1 = <::Base as Field>::ZERO; + let z0 = vec![zero1; arity1]; + let mut cs1 = SatisfyingAssignment::::new(); + let inputs1: SuperNovaAugmentedCircuitInputs<'_, E2> = SuperNovaAugmentedCircuitInputs::new( + scalar_as_base::(zero1), // pass zero for testing + zero1, + &z0, + None, + None, + None, + None, + Some(zero1), + zero1, + ); + let step_circuit = TrivialTestCircuit::default(); + let circuit1: SuperNovaAugmentedCircuit<'_, E2, TrivialTestCircuit<::Base>> = + SuperNovaAugmentedCircuit::new(primary_params, Some(inputs1), &step_circuit, ro_consts1, 2); + if let Err(e) = circuit1.synthesize(&mut cs1) { + panic!("{}", e) + } + let (inst1, witness1) = cs1.r1cs_instance_and_witness(&shape1, &ck1).unwrap(); + // Make sure that this is satisfiable + assert!(shape1.is_sat(&ck1, &inst1, &witness1).is_ok()); + + // Execute the base case for the secondary + let zero2 = <::Base as Field>::ZERO; + let z0 = vec![zero2; arity2]; + let mut cs2 = SatisfyingAssignment::::new(); + let inputs2: SuperNovaAugmentedCircuitInputs<'_, E1> = SuperNovaAugmentedCircuitInputs::new( + scalar_as_base::(zero2), // pass zero for testing + zero2, + &z0, + None, + None, + Some(&inst1), + None, + None, + zero2, + ); + let step_circuit = TrivialSecondaryCircuit::default(); + let circuit2: SuperNovaAugmentedCircuit<'_, E1, TrivialSecondaryCircuit<::Base>> = + SuperNovaAugmentedCircuit::new( + secondary_params, + Some(inputs2), + &step_circuit, + ro_consts2, + 2, + ); + if let Err(e) = circuit2.synthesize(&mut cs2) { + panic!("{}", e) + } + let (inst2, witness2) = cs2.r1cs_instance_and_witness(&shape2, &ck2).unwrap(); + // Make sure that it is satisfiable + assert!(shape2.is_sat(&ck2, &inst2, &witness2).is_ok()); +} + +#[test] +fn test_recursive_circuit() { + let params1 = SuperNovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, true); + let params2 = SuperNovaAugmentedCircuitParams::new(BN_LIMB_WIDTH, BN_N_LIMBS, false); + let ro_consts1: ROConstantsCircuit = PoseidonConstantsCircuit::default(); + let ro_consts2: ROConstantsCircuit = PoseidonConstantsCircuit::default(); + + test_recursive_circuit_with::( + ¶ms1, ¶ms2, ro_consts1, ro_consts2, 9844, 12025, + ); +} + +fn test_pp_digest_with(non_uniform_circuit: &NC, expected: &str) +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, + T1: StepCircuit, + T2: StepCircuit, + NC: NonUniformCircuit, +{ + // TODO: add back in https://github.com/lurk-lab/arecibo/issues/53 + // // this tests public parameters with a size specifically intended for a spark-compressed SNARK + // let pp_hint1 = Some(SPrime::::commitment_key_floor()); + // let pp_hint2 = Some(SPrime::::commitment_key_floor()); + let pp = PublicParams::::setup( + non_uniform_circuit, + &*default_ck_hint(), + &*default_ck_hint(), + ); + + let digest_str = pp + .digest() + .to_repr() + .as_ref() + .iter() + .fold(String::new(), |mut output, b| { + let _ = write!(output, "{b:02x}"); + output + }); + assert_eq!(digest_str, expected); +} + +#[test] +fn test_supernova_pp_digest() { + let rom = vec![ + OPCODE_1, OPCODE_1, OPCODE_0, OPCODE_0, OPCODE_1, OPCODE_1, OPCODE_0, OPCODE_0, OPCODE_1, + OPCODE_1, + ]; // Rom can be arbitrary length. + let test_rom = TestROM::< + PallasEngine, + VestaEngine, + TrivialSecondaryCircuit<::Scalar>, + >::new(rom); + + test_pp_digest_with::( + &test_rom, + "7e203fdfeab0ee8f56f8948497f8de73539d52e64cef89e44fff84711cf8b100", + ); + + let rom = vec![ + OPCODE_1, OPCODE_1, OPCODE_0, OPCODE_0, OPCODE_1, OPCODE_1, OPCODE_0, OPCODE_0, OPCODE_1, + OPCODE_1, + ]; // Rom can be arbitrary length. + let test_rom_grumpkin = TestROM::< + Bn256Engine, + GrumpkinEngine, + TrivialSecondaryCircuit<::Scalar>, + >::new(rom); + + test_pp_digest_with::( + &test_rom_grumpkin, + "6f72db6927b6a12e95e1d5237298e1e20f0215b63ef8d76a361930eb76f71003", + ); + + let rom = vec![ + OPCODE_1, OPCODE_1, OPCODE_0, OPCODE_0, OPCODE_1, OPCODE_1, OPCODE_0, OPCODE_0, OPCODE_1, + OPCODE_1, + ]; // Rom can be arbitrary length. + let test_rom_secp = TestROM::< + Secp256k1Engine, + Secq256k1Engine, + TrivialSecondaryCircuit<::Scalar>, + >::new(rom); + + test_pp_digest_with::( + &test_rom_secp, + "0c2f7c68efcc5f4c42a25670ea896bc082c9753d04fc2e5b3a41531ed4e91602", + ); +} + +// y is a non-deterministic hint representing the cube root of the input at a step. +#[derive(Clone, Debug)] +struct CubeRootCheckingCircuit { + y: Option, +} + +impl StepCircuit for CubeRootCheckingCircuit +where + F: PrimeField, +{ + fn arity(&self) -> usize { + 1 + } + + fn circuit_index(&self) -> usize { + 0 + } + + fn synthesize>( + &self, + cs: &mut CS, + _pc: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result<(Option>, Vec>), SynthesisError> { + let x = &z[0]; + + // we allocate a variable and set it to the provided non-deterministic hint. + let y = AllocatedNum::alloc(cs.namespace(|| "y"), || { + self.y.ok_or(SynthesisError::AssignmentMissing) + })?; + + // We now check if y = x^{1/3} by checking if y^3 = x + let y_sq = y.square(cs.namespace(|| "y_sq"))?; + let y_cube = y_sq.mul(cs.namespace(|| "y_cube"), &y)?; + + cs.enforce( + || "y^3 = x", + |lc| lc + y_cube.get_variable(), + |lc| lc + CS::one(), + |lc| lc + x.get_variable(), + ); + + let next_pc = alloc_one(&mut cs.namespace(|| "next_pc")); + + Ok((Some(next_pc), vec![y])) + } +} + +// y is a non-deterministic hint representing the fifth root of the input at a step. +#[derive(Clone, Debug)] +struct FifthRootCheckingCircuit { + y: Option, +} + +impl StepCircuit for FifthRootCheckingCircuit +where + F: PrimeField, +{ + fn arity(&self) -> usize { + 1 + } + + fn circuit_index(&self) -> usize { + 1 + } + + fn synthesize>( + &self, + cs: &mut CS, + _pc: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result<(Option>, Vec>), SynthesisError> { + let x = &z[0]; + + // we allocate a variable and set it to the provided non-deterministic hint. + let y = AllocatedNum::alloc(cs.namespace(|| "y"), || { + self.y.ok_or(SynthesisError::AssignmentMissing) + })?; + + // We now check if y = x^{1/5} by checking if y^5 = x + let y_sq = y.square(cs.namespace(|| "y_sq"))?; + let y_quad = y_sq.square(cs.namespace(|| "y_quad"))?; + let y_pow_5 = y_quad.mul(cs.namespace(|| "y_fifth"), &y)?; + + cs.enforce( + || "y^5 = x", + |lc| lc + y_pow_5.get_variable(), + |lc| lc + CS::one(), + |lc| lc + x.get_variable(), + ); + + let next_pc = alloc_zero(&mut cs.namespace(|| "next_pc")); + + Ok((Some(next_pc), vec![y])) + } +} + +#[derive(Clone, Debug)] +enum RootCheckingCircuit { + Cube(CubeRootCheckingCircuit), + Fifth(FifthRootCheckingCircuit), +} + +impl RootCheckingCircuit { + fn new(num_steps: usize) -> (Vec, Vec) { + let mut powers = Vec::new(); + let rng = &mut rand::rngs::OsRng; + let mut seed = F::random(rng); + + for i in 0..num_steps + 1 { + let seed_sq = seed.clone().square(); + // Cube-root and fifth-root circuits alternate. We compute the hints backward, so the calculations appear to be + // associated with the 'wrong' circuit. The final circuit is discarded, and only the final seed is used (as z_0). + powers.push(if i % 2 == num_steps % 2 { + seed *= seed_sq; + Self::Fifth(FifthRootCheckingCircuit { y: Some(seed) }) + } else { + seed *= seed_sq.clone().square(); + Self::Cube(CubeRootCheckingCircuit { y: Some(seed) }) + }) + } + + // reverse the powers to get roots + let roots = powers.into_iter().rev().collect::>(); + (vec![roots[0].get_y().unwrap()], roots[1..].to_vec()) + } + + fn get_y(&self) -> Option { + match self { + Self::Fifth(x) => x.y, + Self::Cube(x) => x.y, + } + } +} + +impl StepCircuit for RootCheckingCircuit +where + F: PrimeField, +{ + fn arity(&self) -> usize { + 1 + } + + fn circuit_index(&self) -> usize { + match self { + Self::Cube(x) => x.circuit_index(), + Self::Fifth(x) => x.circuit_index(), + } + } + + fn synthesize>( + &self, + cs: &mut CS, + pc: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result<(Option>, Vec>), SynthesisError> { + match self { + Self::Cube(c) => c.synthesize(cs, pc, z), + Self::Fifth(c) => c.synthesize(cs, pc, z), + } + } +} + +impl + NonUniformCircuit, TrivialSecondaryCircuit> + for RootCheckingCircuit +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, +{ + fn num_circuits(&self) -> usize { + 2 + } + + fn primary_circuit(&self, circuit_index: usize) -> Self { + match circuit_index { + 0 => Self::Cube(CubeRootCheckingCircuit { y: None }), + 1 => Self::Fifth(FifthRootCheckingCircuit { y: None }), + _ => unreachable!(), + } + } + + fn secondary_circuit(&self) -> TrivialSecondaryCircuit { + TrivialSecondaryCircuit::::default() + } +} + +fn test_nivc_nondet_with() +where + E1: Engine::Scalar>, + E2: Engine::Scalar>, +{ + let circuit_secondary = TrivialSecondaryCircuit::default(); + + let num_steps = 3; + + // produce non-deterministic hint + let (z0_primary, roots) = RootCheckingCircuit::new(num_steps); + assert_eq!(num_steps, roots.len()); + let z0_secondary = vec![::Scalar::ZERO]; + + // produce public parameters + let pp = PublicParams::< + E1, + E2, + RootCheckingCircuit<::Scalar>, + TrivialSecondaryCircuit<::Scalar>, + >::setup(&roots[0], &*default_ck_hint(), &*default_ck_hint()); + // produce a recursive SNARK + + let circuit_primary = &roots[0]; + + let mut recursive_snark = RecursiveSNARK::::new( + &pp, + circuit_primary, + circuit_primary, + &circuit_secondary, + &z0_primary, + &z0_secondary, + ) + .map_err(|err| { + print_constraints_name_on_error_index(&err, &pp, circuit_primary, &circuit_secondary, 2) + }) + .unwrap(); + + for circuit_primary in roots.iter().take(num_steps) { + let res = recursive_snark.prove_step(&pp, circuit_primary, &circuit_secondary); + assert!(res + .map_err(|err| { + print_constraints_name_on_error_index(&err, &pp, circuit_primary, &circuit_secondary, 2) + }) + .is_ok()); + + // verify the recursive SNARK + let res = recursive_snark + .verify(&pp, &z0_primary, &z0_secondary) + .map_err(|err| { + print_constraints_name_on_error_index(&err, &pp, circuit_primary, &circuit_secondary, 2) + }); + assert!(res.is_ok()); + } +} + +#[test] +fn test_nivc_nondet() { + test_nivc_nondet_with::(); + test_nivc_nondet_with::(); + test_nivc_nondet_with::(); +} diff --git a/src/supernova/utils.rs b/src/supernova/utils.rs new file mode 100644 index 000000000..3c390f433 --- /dev/null +++ b/src/supernova/utils.rs @@ -0,0 +1,179 @@ +use bellpepper_core::{ + boolean::{AllocatedBit, Boolean}, + num::AllocatedNum, + ConstraintSystem, LinearCombination, SynthesisError, +}; +use ff::PrimeField; +use itertools::Itertools as _; + +use crate::{ + gadgets::r1cs::{conditionally_select_alloc_relaxed_r1cs, AllocatedRelaxedR1CSInstance}, + traits::Engine, +}; + +/// Return the element of `a` given by the indicator bit in `selector_vec`. +/// +/// This function assumes `selector_vec` has been properly constrained", i.e. that exactly one entry is equal to 1. +// +// NOTE: When `a` is greater than 5 (estimated), it will be cheaper to use a multicase gadget. +// +// We should plan to rely on a well-designed gadget offering a common interface but that adapts its implementation based +// on the size of inputs (known at synthesis time). The threshold size depends on the size of the elements of `a`. The +// larger the elements, the fewer are needed before multicase becomes cost-effective. +pub fn get_from_vec_alloc_relaxed_r1cs::Base>>( + mut cs: CS, + a: &[AllocatedRelaxedR1CSInstance], + selector_vec: &[Boolean], +) -> Result, SynthesisError> { + assert_eq!(a.len(), selector_vec.len()); + + // Compare all instances in `a` to the first one + let first: AllocatedRelaxedR1CSInstance = a + .get(0) + .cloned() + .ok_or_else(|| SynthesisError::IncompatibleLengthVector("empty vec length".to_string()))?; + + // Since `selector_vec` is correct, only one entry is 1. + // If selector_vec[0] is 1, then all `conditionally_select` will return `first`. + // Otherwise, the correct instance will be selected. + let selected = a + .iter() + .zip_eq(selector_vec.iter()) + .enumerate() + .skip(1) + .try_fold(first, |matched, (i, (candidate, equal_bit))| { + conditionally_select_alloc_relaxed_r1cs( + cs.namespace(|| format!("next_matched_allocated-{:?}", i)), + candidate, + &matched, + equal_bit, + ) + })?; + + Ok(selected) +} + +/// Compute a selector vector `s` of size `num_indices`, such that +/// `s[i] == 1` if i == `target_index` and 0 otherwise. +pub fn get_selector_vec_from_index>( + mut cs: CS, + target_index: &AllocatedNum, + num_indices: usize, +) -> Result, SynthesisError> { + assert_ne!(num_indices, 0); + + // Compute the selector vector non-deterministically + let selector = (0..num_indices) + .map(|idx| { + // b <- idx == target_index + Ok(Boolean::Is(AllocatedBit::alloc( + cs.namespace(|| format!("allocate s_{:?}", idx)), + target_index.get_value().map(|v| v == F::from(idx as u64)), + )?)) + }) + .collect::, SynthesisError>>()?; + + // Enforce ∑ selector[i] = 1 + { + let selected_sum = selector.iter().fold(LinearCombination::zero(), |lc, bit| { + lc + &bit.lc(CS::one(), F::ONE) + }); + cs.enforce( + || "exactly-one-selection", + |_| selected_sum, + |lc| lc + CS::one(), + |lc| lc + CS::one(), + ); + } + + // Enforce `target_index - ∑ i * selector[i] = 0`` + { + let selected_value = selector + .iter() + .enumerate() + .fold(LinearCombination::zero(), |lc, (i, bit)| { + lc + &bit.lc(CS::one(), F::from(i as u64)) + }); + cs.enforce( + || "target_index - ∑ i * selector[i] = 0", + |lc| lc, + |lc| lc, + |lc| lc + target_index.get_variable() - &selected_value, + ); + } + + Ok(selector) +} + +#[cfg(test)] +mod test { + use crate::provider::PallasEngine; + + use super::*; + use bellpepper_core::test_cs::TestConstraintSystem; + use pasta_curves::pallas::Base; + + #[test] + fn test_get_from_vec_alloc_relaxed_r1cs_bounds() { + let n = 3; + for selected in 0..(2 * n) { + let mut cs = TestConstraintSystem::::new(); + + let allocated_target = AllocatedNum::alloc_infallible(&mut cs.namespace(|| "target"), || { + Base::from(selected as u64) + }); + + let selector_vec = get_selector_vec_from_index(&mut cs, &allocated_target, n).unwrap(); + + let vec = (0..n) + .map(|i| { + AllocatedRelaxedR1CSInstance::::default( + &mut cs.namespace(|| format!("elt-{i}")), + 4, + 64, + ) + .unwrap() + }) + .collect::>(); + + get_from_vec_alloc_relaxed_r1cs(&mut cs.namespace(|| "test-fn"), &vec, &selector_vec) + .unwrap(); + + if selected < n { + assert!(cs.is_satisfied()) + } else { + // If selected is out of range, the circuit must be unsatisfied. + assert!(!cs.is_satisfied()) + } + } + } + + #[test] + fn test_get_selector() { + for n in 1..4 { + for selected in 0..(2 * n) { + let mut cs = TestConstraintSystem::::new(); + + let allocated_target = + AllocatedNum::alloc_infallible(&mut cs.namespace(|| "target"), || { + Base::from(selected as u64) + }); + + let selector_vec = get_selector_vec_from_index(&mut cs, &allocated_target, n).unwrap(); + + if selected < n { + // Check that the selector bits are correct + assert_eq!(selector_vec.len(), n); + for (i, bit) in selector_vec.iter().enumerate() { + assert_eq!(bit.get_value().unwrap(), i == selected); + } + + assert!(cs.is_satisfied()); + } else { + // If selected is out of range, the circuit must be unsatisfied. + assert!(!cs.is_satisfied()); + } + } + } + } +} diff --git a/src/traits/circuit_supernova.rs b/src/traits/circuit_supernova.rs new file mode 100644 index 000000000..4ea7b90e2 --- /dev/null +++ b/src/traits/circuit_supernova.rs @@ -0,0 +1,115 @@ +//! This module defines traits that a supernova step function must implement +use bellpepper_core::{num::AllocatedNum, ConstraintSystem, SynthesisError}; +use core::marker::PhantomData; +use ff::PrimeField; + +/// A helper trait for a step of the incremental computation for `SuperNova` (i.e., circuit for F) -- to be implemented by +/// applications. +pub trait StepCircuit: Send + Sync + Clone { + /// Return the the number of inputs or outputs of each step + /// (this method is called only at circuit synthesis time) + /// `synthesize` and `output` methods are expected to take as + /// input a vector of size equal to arity and output a vector of size equal to arity + fn arity(&self) -> usize; + + /// Return this `StepCircuit`'s assigned index, for use when enforcing the program counter. + fn circuit_index(&self) -> usize; + + /// Synthesize the circuit for a computation step and return variable + /// that corresponds to the output of the step `pc_{i+1}` and `z_{i+1}` + fn synthesize>( + &self, + cs: &mut CS, + pc: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result<(Option>, Vec>), SynthesisError>; +} + +/// A helper trait for a step of the incremental computation for `SuperNova` (i.e., circuit for F) -- automatically +/// implemented for `StepCircuit` and used internally to enforce that the circuit selected by the program counter is used +/// at each step. +pub trait EnforcingStepCircuit: Send + Sync + Clone + StepCircuit { + /// Delegate synthesis to `StepCircuit::synthesize`, and additionally, enforce the constraint that program counter + /// `pc`, if supplied, is equal to the circuit's assigned index. + fn enforcing_synthesize>( + &self, + cs: &mut CS, + pc: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result<(Option>, Vec>), SynthesisError> { + if let Some(pc) = pc { + let circuit_index = F::from(self.circuit_index() as u64); + + // pc * 1 = circuit_index + cs.enforce( + || "pc matches circuit index", + |lc| lc + pc.get_variable(), + |lc| lc + CS::one(), + |lc| lc + (circuit_index, CS::one()), + ); + } + self.synthesize(cs, pc, z) + } +} + +impl> EnforcingStepCircuit for S {} + +/// A trivial step circuit that simply returns the input +#[derive(Clone, Debug, Default)] +pub struct TrivialTestCircuit { + _p: PhantomData, +} + +impl StepCircuit for TrivialTestCircuit +where + F: PrimeField, +{ + fn arity(&self) -> usize { + 1 + } + + fn circuit_index(&self) -> usize { + 0 + } + + fn synthesize>( + &self, + _cs: &mut CS, + program_counter: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result<(Option>, Vec>), SynthesisError> { + Ok((program_counter.cloned(), z.to_vec())) + } +} + +/// A trivial step circuit that simply returns the input, for use on the secondary circuit when implementing NIVC. +/// NOTE: This should not be needed. The secondary circuit doesn't need the program counter at all. +/// Ideally, the need this fills could be met by `traits::circuit::TrivialTestCircuit` (or equivalent). +#[derive(Clone, Debug, Default)] +pub struct TrivialSecondaryCircuit { + _p: PhantomData, +} + +impl StepCircuit for TrivialSecondaryCircuit +where + F: PrimeField, +{ + fn arity(&self) -> usize { + 1 + } + + fn circuit_index(&self) -> usize { + 0 + } + + fn synthesize>( + &self, + _cs: &mut CS, + program_counter: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result<(Option>, Vec>), SynthesisError> { + assert!(program_counter.is_none()); + assert_eq!(z.len(), 1, "Arity of trivial step circuit should be 1"); + Ok((None, z.to_vec())) + } +} diff --git a/src/traits/mod.rs b/src/traits/mod.rs index f26572eed..29cab88cd 100644 --- a/src/traits/mod.rs +++ b/src/traits/mod.rs @@ -145,5 +145,6 @@ impl> TranscriptReprTrait for &[T] { } pub mod circuit; +pub mod circuit_supernova; pub mod evaluation; pub mod snark; diff --git a/src/traits/snark.rs b/src/traits/snark.rs index b162185e9..99af972a5 100644 --- a/src/traits/snark.rs +++ b/src/traits/snark.rs @@ -55,6 +55,47 @@ pub trait RelaxedR1CSSNARKTrait: fn verify(&self, vk: &Self::VerifierKey, U: &RelaxedR1CSInstance) -> Result<(), NovaError>; } +/// A trait that defines the behavior of a `zkSNARK` to prove knowledge of satisfying witness to batches of relaxed R1CS instances. +pub trait BatchedRelaxedR1CSSNARKTrait: + Send + Sync + Serialize + for<'de> Deserialize<'de> +{ + /// A type that represents the prover's key + type ProverKey: Send + Sync + Serialize + for<'de> Deserialize<'de>; + + /// A type that represents the verifier's key + type VerifierKey: Send + + Sync + + Serialize + + for<'de> Deserialize<'de> + + DigestHelperTrait; + + /// This associated function (not a method) provides a hint that offers + /// a minimum sizing cue for the commitment key used by this SNARK + /// implementation. The commitment key passed in setup should then + /// be at least as large as this hint. + fn ck_floor() -> Box Fn(&'a R1CSShape) -> usize> { + default_ck_hint() + } + + /// Produces the keys for the prover and the verifier + fn setup( + ck: &CommitmentKey, + S: Vec<&R1CSShape>, + ) -> Result<(Self::ProverKey, Self::VerifierKey), NovaError>; + + /// Produces a new SNARK for a batch of relaxed R1CS + fn prove( + ck: &CommitmentKey, + pk: &Self::ProverKey, + S: Vec<&R1CSShape>, + U: &[RelaxedR1CSInstance], + W: &[RelaxedR1CSWitness], + ) -> Result; + + /// Verifies a SNARK for a batch of relaxed R1CS + fn verify(&self, vk: &Self::VerifierKey, U: &[RelaxedR1CSInstance]) -> Result<(), NovaError>; +} + /// A helper trait that defines the behavior of a verifier key of `zkSNARK` pub trait DigestHelperTrait { /// Returns the digest of the verifier's key