Skip to content

Commit

Permalink
Parallelize SHPLONK multi-open prover (#114)
Browse files Browse the repository at this point in the history
* feat: parallelize (cpu) shplonk prover

* shplonk: improve `construct_intermediate_sets` using `BTreeSet` and
`BTreeMap` more aggressively

* shplonk: add `Send` and `Sync` to `Query` trait for more parallelization

* fix: ensure the order of the collection of rotation sets is independent
of the values of the opening points

Co-authored-by: Jonathan Wang <[email protected]>
  • Loading branch information
jonathanpwang and jonathanpwang authored Jan 10, 2023
1 parent 0af4611 commit b8e458e
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 72 deletions.
2 changes: 1 addition & 1 deletion halo2_proofs/src/poly/commitment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ pub trait ParamsProver<'params, C: CurveAffine>: Params<'params, C> {
pub trait ParamsVerifier<'params, C: CurveAffine>: Params<'params, C> {}

/// Multi scalar multiplication engine
pub trait MSM<C: CurveAffine>: Clone + Debug {
pub trait MSM<C: CurveAffine>: Clone + Debug + Send + Sync {
/// Add arbitrary term (the scalar and the point)
fn append_term(&mut self, scalar: C::Scalar, point: C::CurveExt);

Expand Down
84 changes: 37 additions & 47 deletions halo2_proofs/src/poly/kzg/multiopen/shplonk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ use crate::{
poly::{query::Query, Coeff, Polynomial},
transcript::ChallengeScalar,
};

use rayon::prelude::*;
use std::{
collections::{btree_map::Entry, BTreeMap, BTreeSet},
collections::{btree_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet},
marker::PhantomData,
sync::Arc,
};

#[derive(Clone, Copy, Debug)]
Expand Down Expand Up @@ -49,7 +50,7 @@ struct RotationSet<F: FieldExt, T: PartialEq + Clone> {
#[derive(Debug, PartialEq)]
struct IntermediateSets<F: FieldExt, Q: Query<F>> {
rotation_sets: Vec<RotationSet<F, Q::Commitment>>,
super_point_set: Vec<F>,
super_point_set: BTreeSet<F>,
}

fn construct_intermediate_sets<F: FieldExt, I, Q: Query<F, Eval = F>>(
Expand All @@ -69,18 +70,8 @@ where
.get_eval()
};

// Order points according to their rotation
let mut rotation_point_map = BTreeMap::new();
for query in queries.clone() {
let point = rotation_point_map
.entry(query.get_point())
.or_insert_with(|| query.get_point());

// Assert rotation point matching consistency
assert_eq!(*point, query.get_point());
}
// All points appear in queries
let super_point_set: Vec<F> = rotation_point_map.values().cloned().collect();
// All points that appear in queries
let mut super_point_set = BTreeSet::new();

// Collect rotation sets for each commitment
// Example elements in the vector:
Expand All @@ -89,19 +80,21 @@ where
// (C_2, {r_2, r_3, r_4}),
// (C_3, {r_2, r_3, r_4}),
// ...
let mut commitment_rotation_set_map: Vec<(Q::Commitment, Vec<F>)> = vec![];
for query in queries.clone() {
let mut commitment_rotation_set_map: Vec<(Q::Commitment, BTreeSet<F>)> = vec![];
for query in queries.iter() {
let rotation = query.get_point();
if let Some(pos) = commitment_rotation_set_map
.iter()
.position(|(commitment, _)| *commitment == query.get_commitment())
super_point_set.insert(rotation);
if let Some(commitment_rotation_set) = commitment_rotation_set_map
.iter_mut()
.find(|(commitment, _)| *commitment == query.get_commitment())
{
let (_, rotation_set) = &mut commitment_rotation_set_map[pos];
if !rotation_set.contains(&rotation) {
rotation_set.push(rotation);
}
let (_, rotation_set) = commitment_rotation_set;
rotation_set.insert(rotation);
} else {
commitment_rotation_set_map.push((query.get_commitment(), vec![rotation]));
commitment_rotation_set_map.push((
query.get_commitment(),
BTreeSet::from_iter(std::iter::once(rotation)),
));
};
}

Expand All @@ -111,41 +104,38 @@ where
// {r_1, r_2, r_3} : [C_1]
// {r_2, r_3, r_4} : [C_2, C_3],
// ...
let mut rotation_set_commitment_map = Vec::<(Vec<_>, Vec<Q::Commitment>)>::new();
for (commitment, rotation_set) in commitment_rotation_set_map.iter() {
if let Some(pos) = rotation_set_commitment_map.iter().position(|(set, _)| {
BTreeSet::<F>::from_iter(set.iter().cloned())
== BTreeSet::<F>::from_iter(rotation_set.iter().cloned())
}) {
let (_, commitments) = &mut rotation_set_commitment_map[pos];
if !commitments.contains(commitment) {
commitments.push(*commitment);
}
// NOTE: we want to make the order of the collection of rotation sets independent of the opening points, to ease the verifier computation
let mut rotation_set_commitment_map: Vec<(BTreeSet<F>, Vec<Q::Commitment>)> = vec![];
for (commitment, rotation_set) in commitment_rotation_set_map.into_iter() {
if let Some(rotation_set_commitment) = rotation_set_commitment_map
.iter_mut()
.find(|(set, _)| set == &rotation_set)
{
let (_, commitments) = rotation_set_commitment;
commitments.push(commitment);
} else {
rotation_set_commitment_map.push((rotation_set.clone(), vec![*commitment]))
}
rotation_set_commitment_map.push((rotation_set, vec![commitment]));
};
}

let rotation_sets = rotation_set_commitment_map
.into_iter()
.into_par_iter()
.map(|(rotations, commitments)| {
let rotations_vec = rotations.iter().collect::<Vec<_>>();
let commitments: Vec<Commitment<F, Q::Commitment>> = commitments
.iter()
.into_par_iter()
.map(|commitment| {
let evals: Vec<F> = rotations
.iter()
.map(|rotation| get_eval(*commitment, *rotation))
let evals: Vec<F> = rotations_vec
.par_iter()
.map(|&&rotation| get_eval(commitment, rotation))
.collect();
Commitment((*commitment, evals))
Commitment((commitment, evals))
})
.collect();

RotationSet {
commitments,
points: rotations
.iter()
.map(|rotation| *rotation_point_map.get(rotation).unwrap())
.collect(),
points: rotations.into_iter().collect(),
}
})
.collect::<Vec<RotationSet<_, _>>>();
Expand Down
51 changes: 29 additions & 22 deletions halo2_proofs/src/poly/kzg/multiopen/shplonk/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use ff::Field;
use group::Curve;
use halo2curves::pairing::Engine;
use rand_core::RngCore;
use rayon::prelude::*;
use std::fmt::Debug;
use std::io::{self, Write};
use std::marker::PhantomData;
Expand All @@ -36,8 +37,8 @@ struct CommitmentExtension<'a, C: CurveAffine> {
}

impl<'a, C: CurveAffine> Commitment<C::Scalar, PolynomialPointer<'a, C>> {
fn extend(&self, points: Vec<C::Scalar>) -> CommitmentExtension<'a, C> {
let poly = lagrange_interpolate(&points[..], &self.evals()[..]);
fn extend(&self, points: &[C::Scalar]) -> CommitmentExtension<'a, C> {
let poly = lagrange_interpolate(points, &self.evals()[..]);

let low_degree_equivalent = Polynomial {
values: poly,
Expand Down Expand Up @@ -79,10 +80,10 @@ struct RotationSetExtension<'a, C: CurveAffine> {
}

impl<'a, C: CurveAffine> RotationSet<C::Scalar, PolynomialPointer<'a, C>> {
fn extend(&self, commitments: Vec<CommitmentExtension<'a, C>>) -> RotationSetExtension<'a, C> {
fn extend(self, commitments: Vec<CommitmentExtension<'a, C>>) -> RotationSetExtension<'a, C> {
RotationSetExtension {
commitments,
points: self.points.clone(),
points: self.points,
}
}
}
Expand Down Expand Up @@ -136,15 +137,17 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme<E>>
// [P_i_0(X) - R_i_0(X), P_i_1(X) - R_i_1(X), ... ]
let numerators = rotation_set
.commitments
.iter()
.map(|commitment| commitment.quotient_contribution());
.par_iter()
.map(|commitment| commitment.quotient_contribution())
.collect::<Vec<_>>();

// define numerator polynomial as
// N_i_j(X) = (P_i_j(X) - R_i_j(X))
// and combine polynomials with same evaluation point set
// N_i(X) = linear_combinination(y, N_i_j(X))
// where y is random scalar to combine numerator polynomials
let n_x = numerators
.into_iter()
.zip(powers(*y))
.map(|(numerator, power_of_y)| numerator * power_of_y)
.reduce(|acc, numerator| acc + &numerator)
Expand All @@ -171,22 +174,26 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme<E>>
);

let rotation_sets: Vec<RotationSetExtension<E::G1Affine>> = rotation_sets
.iter()
.into_par_iter()
.map(|rotation_set| {
let commitments: Vec<CommitmentExtension<E::G1Affine>> = rotation_set
.commitments
.iter()
.map(|commitment_data| commitment_data.extend(rotation_set.points.clone()))
.par_iter()
.map(|commitment_data| commitment_data.extend(&rotation_set.points))
.collect();
rotation_set.extend(commitments)
})
.collect();

let v: ChallengeV<_> = transcript.squeeze_challenge_scalar();

let quotient_polynomials = rotation_sets.iter().map(quotient_contribution);
let quotient_polynomials = rotation_sets
.par_iter()
.map(quotient_contribution)
.collect::<Vec<_>>();

let h_x: Polynomial<E::Scalar, Coeff> = quotient_polynomials
.into_iter()
.zip(powers(*v))
.map(|(poly, power_of_v)| poly * power_of_v)
.reduce(|acc, poly| acc + &poly)
Expand All @@ -196,18 +203,15 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme<E>>
transcript.write_point(h)?;
let u: ChallengeU<_> = transcript.squeeze_challenge_scalar();

let zt_eval = evaluate_vanishing_polynomial(&super_point_set[..], *u);

let linearisation_contribution =
|rotation_set: RotationSetExtension<E::G1Affine>| -> (Polynomial<E::Scalar, Coeff>, E::Scalar) {
let diffs: Vec<E::Scalar> = super_point_set
.iter()
.filter(|point| !rotation_set.points.contains(point))
.copied()
.collect();
let mut diffs = super_point_set.clone();
for point in rotation_set.points.iter() {
diffs.remove(point);
}
let diffs = diffs.into_iter().collect::<Vec<_>>();

// calculate difference vanishing polynomial evaluation

let z_i = evaluate_vanishing_polynomial(&diffs[..], *u);

// inner linearisation contibutions are
Expand All @@ -216,15 +220,15 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme<E>>
// where u is random evaluation point
let inner_contributions = rotation_set
.commitments
.iter()
.map(|commitment| commitment.linearisation_contribution(*u));
.par_iter()
.map(|commitment| commitment.linearisation_contribution(*u)).collect::<Vec<_>>();

// define inner contributor polynomial as
// L_i_j(X) = (P_i_j(X) - r_i_j)
// and combine polynomials with same evaluation point set
// L_i(X) = linear_combinination(y, L_i_j(X))
// where y is random scalar to combine inner contibutors
let l_x: Polynomial<E::Scalar, Coeff> = inner_contributions.zip(powers(*y)).map(|(poly, power_of_y)| poly * power_of_y).reduce(|acc, poly| acc + &poly).unwrap();
let l_x: Polynomial<E::Scalar, Coeff> = inner_contributions.into_iter().zip(powers(*y)).map(|(poly, power_of_y)| poly * power_of_y).reduce(|acc, poly| acc + &poly).unwrap();

// finally scale l_x by difference vanishing polynomial evaluation z_i
(l_x * z_i, z_i)
Expand All @@ -235,7 +239,7 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme<E>>
Vec<Polynomial<E::Scalar, Coeff>>,
Vec<E::Scalar>,
) = rotation_sets
.into_iter()
.into_par_iter()
.map(linearisation_contribution)
.unzip();

Expand All @@ -246,9 +250,12 @@ impl<'params, E: Engine + Debug> Prover<'params, KZGCommitmentScheme<E>>
.reduce(|acc, poly| acc + &poly)
.unwrap();

let super_point_set = super_point_set.into_iter().collect::<Vec<_>>();
let zt_eval = evaluate_vanishing_polynomial(&super_point_set[..], *u);
let l_x = l_x - &(h_x * zt_eval);

// sanity check
#[cfg(debug_assertions)]
{
let must_be_zero = eval_polynomial(&l_x.values[..], *u);
assert_eq!(must_be_zero, E::Scalar::zero());
Expand Down
4 changes: 2 additions & 2 deletions halo2_proofs/src/poly/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use crate::{
use ff::Field;
use halo2curves::CurveAffine;

pub trait Query<F>: Sized + Clone {
type Commitment: PartialEq + Copy;
pub trait Query<F>: Sized + Clone + Send + Sync {
type Commitment: PartialEq + Copy + Send + Sync;
type Eval: Clone + Default + Debug;

fn get_point(&self) -> F;
Expand Down

0 comments on commit b8e458e

Please sign in to comment.