diff --git a/examples/src/fibonacci/fib2/air.rs b/examples/src/fibonacci/fib2/air.rs index 150d8a777..2689a48f8 100644 --- a/examples/src/fibonacci/fib2/air.rs +++ b/examples/src/fibonacci/fib2/air.rs @@ -95,7 +95,7 @@ impl air::LogUpGkrEvaluator for PlainLogUpGkrEval { unimplemented!() } - fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E]) -> Vec + fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E], _query: &mut [E]) where E: FieldElement, { diff --git a/examples/src/fibonacci/fib8/air.rs b/examples/src/fibonacci/fib8/air.rs index a9573634b..765f3980d 100644 --- a/examples/src/fibonacci/fib8/air.rs +++ b/examples/src/fibonacci/fib8/air.rs @@ -103,7 +103,7 @@ impl air::LogUpGkrEvaluator for PlainLogUpGkrEval { unimplemented!() } - fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E]) -> Vec + fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E], _query: &mut [E]) where E: FieldElement, { diff --git a/examples/src/fibonacci/fib_small/air.rs b/examples/src/fibonacci/fib_small/air.rs index c567394ed..e32330a9b 100644 --- a/examples/src/fibonacci/fib_small/air.rs +++ b/examples/src/fibonacci/fib_small/air.rs @@ -95,7 +95,7 @@ impl air::LogUpGkrEvaluator for PlainLogUpGkrEval { unimplemented!() } - fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E]) -> Vec + fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E], _query: &mut [E]) where E: FieldElement, { diff --git a/examples/src/fibonacci/mulfib2/air.rs b/examples/src/fibonacci/mulfib2/air.rs index ba9427443..323ed623f 100644 --- a/examples/src/fibonacci/mulfib2/air.rs +++ b/examples/src/fibonacci/mulfib2/air.rs @@ -97,7 +97,7 @@ impl air::LogUpGkrEvaluator for PlainLogUpGkrEval { unimplemented!() } - fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E]) -> Vec + fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E], _query: &mut [E]) where E: FieldElement, { diff --git a/examples/src/fibonacci/mulfib8/air.rs b/examples/src/fibonacci/mulfib8/air.rs index f7b04b0ae..2c63fe907 100644 --- a/examples/src/fibonacci/mulfib8/air.rs +++ b/examples/src/fibonacci/mulfib8/air.rs @@ -118,7 +118,7 @@ impl air::LogUpGkrEvaluator for PlainLogUpGkrEval { unimplemented!() } - fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E]) -> Vec + fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E], _query: &mut [E]) where E: FieldElement, { diff --git a/examples/src/lamport/aggregate/air.rs b/examples/src/lamport/aggregate/air.rs index 58331527e..71cefe6ac 100644 --- a/examples/src/lamport/aggregate/air.rs +++ b/examples/src/lamport/aggregate/air.rs @@ -316,7 +316,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { unimplemented!() } - fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E]) -> Vec + fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E], _query: &mut [E]) where E: FieldElement, { diff --git a/examples/src/lamport/threshold/air.rs b/examples/src/lamport/threshold/air.rs index afac4f8c3..be93d0295 100644 --- a/examples/src/lamport/threshold/air.rs +++ b/examples/src/lamport/threshold/air.rs @@ -392,7 +392,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { unimplemented!() } - fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E]) -> Vec + fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E], _query: &mut [E]) where E: FieldElement, { diff --git a/examples/src/merkle/air.rs b/examples/src/merkle/air.rs index e605f7ffe..138bddf03 100644 --- a/examples/src/merkle/air.rs +++ b/examples/src/merkle/air.rs @@ -162,7 +162,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { unimplemented!() } - fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E]) -> Vec + fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E], _query: &mut [E]) where E: FieldElement, { diff --git a/examples/src/rescue/air.rs b/examples/src/rescue/air.rs index 06a994c64..42c8e772f 100644 --- a/examples/src/rescue/air.rs +++ b/examples/src/rescue/air.rs @@ -167,7 +167,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { unimplemented!() } - fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E]) -> Vec + fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E], _query: &mut [E]) where E: FieldElement, { diff --git a/examples/src/rescue_raps/air.rs b/examples/src/rescue_raps/air.rs index a2f2e5312..e51f6f42f 100644 --- a/examples/src/rescue_raps/air.rs +++ b/examples/src/rescue_raps/air.rs @@ -297,7 +297,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { unimplemented!() } - fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E]) -> Vec + fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E], _query: &mut [E]) where E: FieldElement, { diff --git a/examples/src/vdf/exempt/air.rs b/examples/src/vdf/exempt/air.rs index 2152d632f..b02258b20 100644 --- a/examples/src/vdf/exempt/air.rs +++ b/examples/src/vdf/exempt/air.rs @@ -109,7 +109,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { unimplemented!() } - fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E]) -> Vec + fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E], _query: &mut [E]) where E: FieldElement, { diff --git a/examples/src/vdf/regular/air.rs b/examples/src/vdf/regular/air.rs index 6e1e2fc32..1a7ed8827 100644 --- a/examples/src/vdf/regular/air.rs +++ b/examples/src/vdf/regular/air.rs @@ -100,7 +100,7 @@ impl LogUpGkrEvaluator for PlainLogUpGkrEval { unimplemented!() } - fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E]) -> Vec + fn build_query(&self, _frame: &EvaluationFrame, _periodic_values: &[E], _query: &mut [E]) where E: FieldElement, { diff --git a/prover/src/logup_gkr/prover.rs b/prover/src/logup_gkr/prover.rs index e6dde54f3..f95597825 100644 --- a/prover/src/logup_gkr/prover.rs +++ b/prover/src/logup_gkr/prover.rs @@ -73,11 +73,11 @@ pub fn prove_gkr( prove_intermediate_layers(&mut circuit, public_coin)?; // build the MLEs of the relevant main trace columns - let main_trace_mls = + let mut main_trace_mls = build_mls_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?; let final_layer_proof = - prove_input_layer(evaluator, logup_randomness, main_trace_mls, gkr_claim, public_coin)?; + prove_input_layer(evaluator, logup_randomness, &mut main_trace_mls, gkr_claim, public_coin)?; // include the circuit output as part of the final proof let CircuitLayerPolys { numerators, denominators } = circuit.output_layer().clone(); @@ -97,12 +97,12 @@ fn prove_input_layer< >( evaluator: &impl LogUpGkrEvaluator, log_up_randomness: Vec, - mut mls: Vec>, - gkr_claim: GkrClaim, + multi_linear_ext_polys: &mut[MultiLinearPoly], + claim: GkrClaim, transcript: &mut C, ) -> Result, GkrProverError> { // parse the [GkrClaim] resulting from the previous GKR layer - let GkrClaim { evaluation_point, claimed_evaluation } = gkr_claim; + let GkrClaim { evaluation_point, claimed_evaluation } = claim; transcript.reseed(H::hash_elements(&[claimed_evaluation.0, claimed_evaluation.1])); let r_batch = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; @@ -114,11 +114,11 @@ fn prove_input_layer< claim, r_batch, log_up_randomness, - &mut mls, + multi_linear_ext_polys, transcript, )?; - Ok(FinalLayerProof { proof }) + Ok(FinalLayerProof::new(proof)) } // TODO: Make the multi-linears over the base field and define an operation of folding with a challenge @@ -172,10 +172,10 @@ fn prove_intermediate_layers< // generate the challenge and reduce [p0, p1, q0, q1] to [pr, qr] let r = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; - let mut claim = circuit.evaluate_output_layer(r); + let mut claimed_evaluation = circuit.evaluate_output_layer(r); - let mut proof_layers: Vec> = Vec::new(); - let mut rand = vec![r]; + let mut layer_proofs: Vec> = Vec::new(); + let mut evaluation_point = vec![r]; // Loop over all inner layers, from output to input. // @@ -186,7 +186,7 @@ fn prove_intermediate_layers< // reduced in terms of the input layer separately in `prove_final_circuit_layer`. for inner_layer in circuit.layers().iter().skip(1).rev().skip(1) { // construct the Lagrange kernel evaluated at the previous GKR round randomness - let mut poly_x = EqFunction::ml_at(rand.into()); + let mut eq_mle = EqFunction::ml_at(evaluation_point.into()); // construct the vector of multi-linear polynomials // TODO: avoid unnecessary allocation @@ -197,12 +197,12 @@ fn prove_intermediate_layers< // run the sumcheck protocol let proof = sum_check_prove_num_rounds_degree_3( - claim, + claimed_evaluation, &mut left_numerators, &mut right_numerators, &mut left_denominators, &mut right_denominators, - &mut poly_x, + &mut eq_mle, transcript, )?; @@ -211,7 +211,7 @@ fn prove_intermediate_layers< let r_layer = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?; // reduce the claim - claim = { + claimed_evaluation = { let left_numerators_opening = proof.openings_claim.openings[0]; let right_numerators_opening = proof.openings_claim.openings[1]; let left_denominators_opening = proof.openings_claim.openings[2]; @@ -229,16 +229,16 @@ fn prove_intermediate_layers< // collect the randomness used for the current layer let mut ext = vec![r_layer]; ext.extend_from_slice(&proof.openings_claim.eval_point); - rand = ext; + evaluation_point = ext; - proof_layers.push(proof); + layer_proofs.push(proof); } Ok(( - BeforeFinalLayerProof { proof: proof_layers }, + BeforeFinalLayerProof { proof: layer_proofs }, GkrClaim { - evaluation_point: rand, - claimed_evaluation: claim, + evaluation_point, + claimed_evaluation, }, )) } diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index f30db974c..5beac9bb9 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -124,8 +124,12 @@ where /// A proof for the input circuit layer i.e., the final layer in the GKR protocol. #[derive(Debug, Clone)] -pub struct FinalLayerProof { - pub proof: SumCheckProof, +pub struct FinalLayerProof(SumCheckProof); + +impl FinalLayerProof { + pub fn new(proof: SumCheckProof) -> Self { + Self(proof) + } } impl Serializable for FinalLayerProof @@ -133,8 +137,7 @@ where E: FieldElement, { fn write_into(&self, target: &mut W) { - let Self { proof } = self; - proof.write_into(target); + self.0.write_into(target); } } @@ -143,9 +146,7 @@ where E: FieldElement, { fn read_from(source: &mut R) -> Result { - Ok(Self { - proof: Deserializable::read_from(source)?, - }) + Ok(Self(Deserializable::read_from(source)?)) } } @@ -170,7 +171,7 @@ pub struct GkrCircuitProof { impl GkrCircuitProof { pub fn get_final_opening_claim(&self) -> FinalOpeningClaim { - self.final_layer_proof.proof.openings_claim.clone() + self.final_layer_proof.0.openings_claim.clone() } } @@ -181,7 +182,7 @@ where fn write_into(&self, target: &mut W) { self.circuit_outputs.write_into(target); self.before_final_layer_proofs.write_into(target); - self.final_layer_proof.proof.write_into(target); + self.final_layer_proof.0.write_into(target); } } diff --git a/sumcheck/src/verifier/mod.rs b/sumcheck/src/verifier/mod.rs index d1cfae3a4..5cc338bb7 100644 --- a/sumcheck/src/verifier/mod.rs +++ b/sumcheck/src/verifier/mod.rs @@ -64,8 +64,6 @@ pub fn verify_sum_check_input_layer, ) -> Result, SumCheckVerifierError> { - let FinalLayerProof { proof } = proof; - // generate challenge to batch sum-checks transcript.reseed(H::hash_elements(&[claim.0, claim.1])); let r_batch: E = transcript @@ -77,17 +75,17 @@ pub fn verify_sum_check_input_layer