Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for periodic columns in LogUp-GKR #304

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
a11700d
feat: math utilities needed for sum-check protocol
Al-Kindi-0 Aug 6, 2024
5e00b98
feat: add sum-check prover and verifier
Al-Kindi-0 Aug 6, 2024
798deb1
tests: add sanity tests for utils
Al-Kindi-0 Aug 6, 2024
c1b67d6
feat: use SmallVec
Al-Kindi-0 Aug 7, 2024
657a24b
feat: add remaining functions for sum-check verifier
Al-Kindi-0 Aug 9, 2024
1fe8cc2
chore: move prover into sub-mod
Al-Kindi-0 Aug 9, 2024
486ea9d
chore: remove utils mod
Al-Kindi-0 Aug 9, 2024
68e6097
chore: move logup evaluator trait to separate file
Al-Kindi-0 Aug 9, 2024
09a231d
feat: add multi-threading support and simplify input sum-check
Al-Kindi-0 Aug 15, 2024
d1e899d
feat: add benchmarks and address feedback
Al-Kindi-0 Aug 19, 2024
f44ba93
feat: address feedback and add benchmarks
Al-Kindi-0 Aug 19, 2024
29d5241
feat: add GKR backend for LogUp-GKR
Al-Kindi-0 Aug 9, 2024
37afc96
chore: remove old way of handling Lagrange kernel
Al-Kindi-0 Aug 12, 2024
59ce28c
chore: remove GkrVerifier
Al-Kindi-0 Aug 12, 2024
f0fdfb2
chore: correct header
Al-Kindi-0 Aug 12, 2024
3583e7c
feat: simplify sum-check for input layer
Al-Kindi-0 Aug 14, 2024
d77985d
feat: add mult-threading to sum-checks and multi-linears
Al-Kindi-0 Aug 14, 2024
0933af9
chore: remove num_rounds from sum-check
Al-Kindi-0 Aug 14, 2024
e0650c1
chore: unify serial and concurrent EQ impls
Al-Kindi-0 Aug 14, 2024
12998fb
fix: formatting
Al-Kindi-0 Aug 14, 2024
8ba5ce9
fix: concurrent feature flag
Al-Kindi-0 Aug 14, 2024
f1511e2
feat: make query a mut ref in build_query
Al-Kindi-0 Aug 19, 2024
f4f2d6c
chore: rebase on sum-check
Al-Kindi-0 Aug 20, 2024
51a33ec
chore: update var names and simplify some structs
Al-Kindi-0 Aug 20, 2024
2ac0a9e
wip: avoid extra allocation when projecting mle
Al-Kindi-0 Aug 20, 2024
8e2f7e6
chore: rebase
Al-Kindi-0 Aug 20, 2024
7de00ae
chore: fix rebase against facebook:logup-gkr
Al-Kindi-0 Aug 21, 2024
7caf36f
chore: address feedback
Al-Kindi-0 Aug 21, 2024
f974dd1
chore: remove TODO
Al-Kindi-0 Aug 21, 2024
452559e
fix: tex formatting issue
Al-Kindi-0 Aug 22, 2024
1ae173c
chore: add dummy GkrLogUp evaluator
Al-Kindi-0 Aug 23, 2024
7e3f9e4
chore: fix clippy warnings
Al-Kindi-0 Aug 23, 2024
ff555de
feat: add support for periodic columns to LogUp-GKR
Al-Kindi-0 Aug 23, 2024
0f7e2e0
fix: bug in num_oracles
Al-Kindi-0 Aug 27, 2024
8bad002
Merge branch 'logup-gkr' into al-gkr-backend-for-logup-gkr-support-pe…
Al-Kindi-0 Aug 30, 2024
7e50a6e
fix: post merge issues
Al-Kindi-0 Aug 30, 2024
05f7d09
fix: nits
Al-Kindi-0 Aug 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions air/src/air/logup_gkr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,27 @@ pub trait LogUpGkrEvaluator: Clone + Sync {
self.get_oracles().to_vec(),
)
}

/// Returns the periodic values used in the LogUp-GKR statement, either as base field element
/// during circuit evaluation or as extension field element during the run of sum-check for
/// the input layer.
fn build_periodic_values<F, E>(&self) -> PeriodicTable<F>
where
F: FieldElement<BaseField = Self::BaseField>,
E: FieldElement<BaseField = Self::BaseField> + ExtensionOf<F>,
{
let mut table = Vec::new();

let oracles = self.get_oracles();

for oracle in oracles {
if let LogUpGkrOracle::PeriodicValue(values) = oracle {
let values = embed_in_extension(values.to_vec());
table.push(values)
}
}
PeriodicTable { table }
}
}

#[derive(Clone, Default)]
Expand Down Expand Up @@ -200,3 +221,76 @@ pub enum LogUpGkrOracle<B: StarkField> {
/// must be a power of 2.
PeriodicValue(Vec<B>),
}

// PERIODIC COLUMNS FOR LOGUP
// =================================================================================================

/// Stores the periodic columns used in a LogUp-GKR statement.
///
/// Each stored periodic column is interpreted as a multi-linear extension polynomial of the column
/// with the given periodic values. Due to the periodic nature of the values, storing, binding of
/// an argument and evaluating the said multi-linear extension can be all done linearly in the size
/// of the smallest cycle defining the periodic values. Hence we only store the values of this
/// smallest cycle. The cycle is assumed throughout to be a power of 2.
#[derive(Clone, Debug, Default, PartialEq, PartialOrd, Eq, Ord)]
pub struct PeriodicTable<E: FieldElement> {
pub table: Vec<Vec<E>>,
}

impl<E> PeriodicTable<E>
where
E: FieldElement,
{
pub fn new<B>(table: Vec<Vec<B>>) -> Self
where
E: FieldElement + ExtensionOf<B>,
B: StarkField,
{
let mut result = vec![];
for col in table.iter() {
let mut res = vec![];
for v in col {
res.push(E::from(*v))
}
result.push(res)
}

Self { table: result }
}

pub fn num_columns(&self) -> usize {
self.table.len()
}

pub fn table(&self) -> &[Vec<E>] {
&self.table
}

pub fn get_periodic_values_at(&self, row: usize) -> Vec<E> {
self.table.iter().map(|col| col[row % col.len()]).collect()
}

pub fn bind_least_significant_variable(&mut self, round_challenge: E) {
for col in self.table.iter_mut() {
if col.len() > 1 {
let num_evals = col.len() >> 1;
for i in 0..num_evals {
col[i] = col[i << 1] + round_challenge * (col[(i << 1) + 1] - col[i << 1]);
}
col.truncate(num_evals)
}
}
}
}

// HELPER
// =================================================================================================

fn embed_in_extension<E: FieldElement>(values: Vec<E::BaseField>) -> Vec<E> {
let mut res = vec![];
for v in values {
res.push(E::from(v))
}

res
}
2 changes: 1 addition & 1 deletion air/src/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ pub use lagrange::{
};

mod logup_gkr;
pub use logup_gkr::{LogUpGkrEvaluator, LogUpGkrOracle};
pub use logup_gkr::{LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable};

mod coefficients;
pub use coefficients::{
Expand Down
1 change: 0 additions & 1 deletion air/src/air/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ impl MockAir {
impl Air for MockAir {
type BaseField = BaseElement;
type PublicInputs = ();
//type LogUpGkrEvaluator = DummyLogUpGkrEval<Self::BaseField, ()>;

fn new(trace_info: TraceInfo, _pub_inputs: (), _options: ProofOptions) -> Self {
let num_assertions = trace_info.meta()[0] as usize;
Expand Down
4 changes: 2 additions & 2 deletions air/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ pub use air::{
DeepCompositionCoefficients, EvaluationFrame, GkrData,
LagrangeConstraintsCompositionCoefficients, LagrangeKernelBoundaryConstraint,
LagrangeKernelConstraints, LagrangeKernelEvaluationFrame, LagrangeKernelRandElements,
LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, LogUpGkrOracle, TraceInfo,
TransitionConstraintDegree, TransitionConstraints,
LagrangeKernelTransitionConstraints, LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable,
TraceInfo, TransitionConstraintDegree, TransitionConstraints,
};
5 changes: 3 additions & 2 deletions prover/src/logup_gkr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ impl<E: FieldElement> EvaluatedCircuit<E> {
log_up_randomness: &[E],
) -> CircuitLayer<E> {
let num_fractions = evaluator.get_num_fractions();
let periodic_values = evaluator.build_periodic_values::<E::BaseField, E>();
let mut input_layer_wires =
Vec::with_capacity(main_trace.main_segment().num_rows() * num_fractions);
let mut main_frame = EvaluationFrame::new(main_trace.main_segment().num_cols());
Expand All @@ -119,8 +120,8 @@ impl<E: FieldElement> EvaluatedCircuit<E> {
for i in 0..main_trace.main_segment().num_rows() {
let wires_from_trace_row = {
main_trace.read_main_frame(i, &mut main_frame);

evaluator.build_query(&main_frame, &[], &mut query);
let periodic_values_row = periodic_values.get_periodic_values_at(i);
evaluator.build_query(&main_frame, &periodic_values_row, &mut query);

evaluator.evaluate_query(
&query,
Expand Down
24 changes: 17 additions & 7 deletions prover/src/logup_gkr/prover.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use alloc::vec::Vec;

use air::{LogUpGkrEvaluator, LogUpGkrOracle};
use air::{LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable};
use crypto::{ElementHasher, RandomCoin};
use math::FieldElement;
use sumcheck::{
Expand Down Expand Up @@ -75,11 +75,18 @@ pub fn prove_gkr<E: FieldElement>(
let (before_final_layer_proofs, gkr_claim) = prove_intermediate_layers(circuit, public_coin)?;

// build the MLEs of the relevant main trace columns
let main_trace_mls =
let (main_trace_mls, mut periodic_table) =
build_mls_from_main_trace_segment(evaluator.get_oracles(), main_trace.main_segment())?;

let final_layer_proof =
prove_input_layer(evaluator, logup_randomness, main_trace_mls, gkr_claim, public_coin)?;
// run the GKR prover for the input layer
let final_layer_proof = prove_input_layer(
evaluator,
logup_randomness,
main_trace_mls,
&mut periodic_table,
gkr_claim,
public_coin,
)?;

Ok(GkrCircuitProof {
circuit_outputs: CircuitOutput { numerators, denominators },
Expand All @@ -97,6 +104,7 @@ fn prove_input_layer<
evaluator: &impl LogUpGkrEvaluator<BaseField = E::BaseField>,
log_up_randomness: Vec<E>,
multi_linear_ext_polys: Vec<MultiLinearPoly<E>>,
periodic_table: &mut PeriodicTable<E>,
claim: GkrClaim<E>,
transcript: &mut C,
) -> Result<FinalLayerProof<E>, GkrProverError> {
Expand All @@ -114,6 +122,7 @@ fn prove_input_layer<
r_batch,
log_up_randomness,
multi_linear_ext_polys,
periodic_table,
transcript,
)?;

Expand All @@ -125,8 +134,9 @@ fn prove_input_layer<
fn build_mls_from_main_trace_segment<E: FieldElement>(
oracles: &[LogUpGkrOracle<E::BaseField>],
main_trace: &ColMatrix<<E as FieldElement>::BaseField>,
) -> Result<Vec<MultiLinearPoly<E>>, GkrProverError> {
) -> Result<(Vec<MultiLinearPoly<E>>, PeriodicTable<E>), GkrProverError> {
let mut mls = vec![];
let mut periodic_values = vec![];

for oracle in oracles {
match oracle {
Expand All @@ -146,10 +156,10 @@ fn build_mls_from_main_trace_segment<E: FieldElement>(
let ml = MultiLinearPoly::from_evaluations(values);
mls.push(ml)
},
LogUpGkrOracle::PeriodicValue(_) => unimplemented!(),
LogUpGkrOracle::PeriodicValue(values) => periodic_values.push(values.to_vec()),
};
}
Ok(mls)
Ok((mls, PeriodicTable::new(periodic_values)))
}

/// Proves all GKR layers except for input layer.
Expand Down
16 changes: 13 additions & 3 deletions sumcheck/benches/sum_check_high_degree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

use std::{marker::PhantomData, time::Duration};

use air::{EvaluationFrame, LogUpGkrEvaluator, LogUpGkrOracle};
use air::{EvaluationFrame, LogUpGkrEvaluator, LogUpGkrOracle, PeriodicTable};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use crypto::{hashers::Blake3_192, DefaultRandomCoin, RandomCoin};
use math::{fields::f64::BaseElement, ExtensionOf, FieldElement, StarkField};
Expand Down Expand Up @@ -37,13 +37,14 @@ fn sum_check_high_degree(c: &mut Criterion) {
)
},
|(
(claim, r_batch, rand_pt, (ml0, ml1, ml2, ml3, ml4)),
(claim, r_batch, rand_pt, (ml0, ml1, ml2, ml3, ml4), periodic_table),
evaluator,
logup_randomness,
transcript,
)| {
let mls = vec![ml0, ml1, ml2, ml3, ml4];
let mut transcript = transcript;
let mut periodic_table = periodic_table;

sum_check_prove_higher_degree(
&evaluator,
Expand All @@ -52,6 +53,7 @@ fn sum_check_high_degree(c: &mut Criterion) {
r_batch,
logup_randomness,
mls,
&mut periodic_table,
&mut transcript,
)
},
Expand All @@ -76,21 +78,29 @@ fn setup_sum_check<E: FieldElement>(
MultiLinearPoly<E>,
MultiLinearPoly<E>,
),
PeriodicTable<E>,
) {
let n = 1 << log_size;
let table = MultiLinearPoly::from_evaluations(rand_vector(n));
let multiplicity = MultiLinearPoly::from_evaluations(rand_vector(n));
let values_0 = MultiLinearPoly::from_evaluations(rand_vector(n));
let values_1 = MultiLinearPoly::from_evaluations(rand_vector(n));
let values_2 = MultiLinearPoly::from_evaluations(rand_vector(n));
let periodic_table = PeriodicTable::default();

// this will not generate the correct claim with overwhelming probability but should be fine
// for benchmarking
let rand_pt: Vec<E> = rand_vector(log_size + 2);
let r_batch: E = rand_value();
let claim: E = rand_value();

(claim, r_batch, rand_pt, (table, multiplicity, values_0, values_1, values_2))
(
claim,
r_batch,
rand_pt,
(table, multiplicity, values_0, values_1, values_2),
periodic_table,
)
}

#[derive(Clone, Default)]
Expand Down
Loading