Skip to content

Commit

Permalink
chore: rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
Al-Kindi-0 committed Aug 20, 2024
1 parent 3afec73 commit ead8aa6
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 31 deletions.
4 changes: 2 additions & 2 deletions air/src/air/logup_gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ use math::{fields::f64::BaseElement, ExtensionOf, FieldElement, StarkField, ToEl
use super::EvaluationFrame;

/// A trait containing the necessary information in order to run the LogUp-GKR protocol of [1].
///
///
/// The trait contains useful information for running the GKR protocol as well as for implementing
/// the univariate IOP for multi-linear evaluation of Section 5 in [1] for the final evaluation
/// check resulting from GKR.
///
///
/// [1]: https://eprint.iacr.org/2023/1284
pub trait LogUpGkrEvaluator: Clone + Sync {
/// Defines the base field of the evaluator.
Expand Down
8 changes: 6 additions & 2 deletions prover/src/logup_gkr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ impl<E: FieldElement> EvaluatedCircuit<E> {
/// Note that the return type is a slice of [`CircuitLayerPolys`] as opposed to
/// [`CircuitLayer`], since the evaluated layers are stored in a representation which can be
/// proved using GKR.
pub fn layers(&self) -> &[CircuitLayerPolys<E>] {
&self.layer_polys
pub fn layers(self) -> Vec<CircuitLayerPolys<E>> {
self.layer_polys
}

/// Returns the numerator/denominator polynomials representing the output layer of the circuit.
Expand Down Expand Up @@ -196,6 +196,10 @@ where
denominators: MultiLinearPoly::from_evaluations(denominators),
}
}

fn into_numerators_denominators(self) -> (MultiLinearPoly<E>, MultiLinearPoly<E>) {
(self.numerators, self.denominators)
}
}

impl<E> Serializable for CircuitLayerPolys<E>
Expand Down
45 changes: 18 additions & 27 deletions prover/src/logup_gkr/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,20 @@ pub fn prove_gkr<E: FieldElement>(
}

// evaluate the GKR fractional sum circuit
let mut circuit = EvaluatedCircuit::new(main_trace, evaluator, &logup_randomness)?;
let circuit = EvaluatedCircuit::new(main_trace, evaluator, &logup_randomness)?;

// include the circuit output as part of the final proof
let CircuitLayerPolys { numerators, denominators } = circuit.output_layer().clone();

// run the GKR prover for all layers except the input layer
let (before_final_layer_proofs, gkr_claim) =
prove_intermediate_layers(&mut circuit, public_coin)?;
let (before_final_layer_proofs, gkr_claim) = prove_intermediate_layers(circuit, public_coin)?;

// build the MLEs of the relevant main trace columns
let mut main_trace_mls =
let main_trace_mls =
build_mls_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?;

let final_layer_proof =
prove_input_layer(evaluator, logup_randomness, &mut main_trace_mls, gkr_claim, public_coin)?;

// include the circuit output as part of the final proof
let CircuitLayerPolys { numerators, denominators } = circuit.output_layer().clone();
prove_input_layer(evaluator, logup_randomness, main_trace_mls, gkr_claim, public_coin)?;

Ok(GkrCircuitProof {
circuit_outputs: CircuitOutput { numerators, denominators },
Expand All @@ -97,7 +96,7 @@ fn prove_input_layer<
>(
evaluator: &impl LogUpGkrEvaluator<BaseField = E::BaseField>,
log_up_randomness: Vec<E>,
multi_linear_ext_polys: &mut[MultiLinearPoly<E>],
multi_linear_ext_polys: Vec<MultiLinearPoly<E>>,
claim: GkrClaim<E>,
transcript: &mut C,
) -> Result<FinalLayerProof<E>, GkrProverError> {
Expand Down Expand Up @@ -159,7 +158,7 @@ fn prove_intermediate_layers<
C: RandomCoin<Hasher = H, BaseField = E::BaseField>,
H: ElementHasher<BaseField = E::BaseField>,
>(
circuit: &mut EvaluatedCircuit<E>,
circuit: EvaluatedCircuit<E>,
transcript: &mut C,
) -> Result<(BeforeFinalLayerProof<E>, GkrClaim<E>), GkrProverError> {
// absorb the circuit output layer. This corresponds to sending the four values of the output
Expand All @@ -184,22 +183,17 @@ fn prove_intermediate_layers<
// loop over all inner layers in order to iteratively reduce a layer in terms of its successor
// layer. Note that we don't include the input layer, since its predecessor layer will be
// reduced in terms of the input layer separately in `prove_final_circuit_layer`.
for inner_layer in circuit.layers().iter().skip(1).rev().skip(1) {
for inner_layer in circuit.layers().into_iter().skip(1).rev().skip(1) {
// construct the Lagrange kernel evaluated at the previous GKR round randomness
let mut eq_mle = EqFunction::ml_at(evaluation_point.into());

// construct the vector of multi-linear polynomials
// TODO: avoid unnecessary allocation
//let (mut left_numerators, mut right_numerators) =
//inner_layer.numerators.project_least_significant_variable();
//let (mut left_denominators, mut right_denominators) =
//inner_layer.denominators.project_least_significant_variable();
let (numerators, denominators) = inner_layer.into_numerators_denominators();

// run the sumcheck protocol
let proof = sum_check_prove_num_rounds_degree_3(
claimed_evaluation,
&inner_layer.numerators,
&inner_layer.denominators,
numerators,
denominators,
&mut eq_mle,
transcript,
)?;
Expand Down Expand Up @@ -234,10 +228,7 @@ fn prove_intermediate_layers<

Ok((
BeforeFinalLayerProof { proof: layer_proofs },
GkrClaim {
evaluation_point,
claimed_evaluation,
},
GkrClaim { evaluation_point, claimed_evaluation },
))
}

Expand All @@ -249,17 +240,17 @@ fn sum_check_prove_num_rounds_degree_3<
H: ElementHasher<BaseField = E::BaseField>,
>(
claim: (E, E),
p: & MultiLinearPoly<E>,
q: & MultiLinearPoly<E>,
p: MultiLinearPoly<E>,
q: MultiLinearPoly<E>,
eq: &mut MultiLinearPoly<E>,
transcript: &mut C,
) -> Result<SumCheckProof<E>, GkrProverError> {
// generate challenge to batch two sumchecks
transcript.reseed(H::hash_elements(&[claim.0, claim.1]));
let r_batch = transcript.draw().map_err(|_| GkrProverError::FailedToGenerateChallenge)?;
let claim_ = claim.0 + claim.1 * r_batch;
let claim = claim.0 + claim.1 * r_batch;

let proof = sumcheck_prove_plain(claim_, r_batch, p, q, eq, transcript)?;
let proof = sumcheck_prove_plain(claim, r_batch, p, q, eq, transcript)?;

Ok(proof)
}

0 comments on commit ead8aa6

Please sign in to comment.