Skip to content

Commit

Permalink
feat: pick changes from zcash#728 and changes of flag test-dev-graph
Browse files Browse the repository at this point in the history
  • Loading branch information
han0110 committed Aug 29, 2023
1 parent 888ec23 commit 1a905a9
Show file tree
Hide file tree
Showing 19 changed files with 138 additions and 41 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ minor version bump.
`halo2` currently uses [rayon](https://github.com/rayon-rs/rayon) for parallel computation.
The `RAYON_NUM_THREADS` environment variable can be used to set the number of threads.

You can disable `rayon` by disabling the `"multicore"` feature.
Warning! Halo2 will lose access to parallelism if you disable the `"multicore"` feature.
This will significantly degrade performance.

## License

Licensed under either of
Expand Down
2 changes: 1 addition & 1 deletion halo2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ all-features = true
rustdoc-args = ["--cfg", "docsrs", "--html-in-header", "katex-header.html"]

[dependencies]
halo2_proofs = { version = "0.2", path = "../halo2_proofs" }
halo2_proofs = { version = "0.2", path = "../halo2_proofs", default-features = false }

[lib]
bench = false
11 changes: 8 additions & 3 deletions halo2_gadgets/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ arrayvec = "0.7.0"
bitvec = "1"
ff = { version = "0.13", features = ["bits"] }
group = "0.13"
halo2_proofs = { version = "0.2", path = "../halo2_proofs" }
halo2_proofs = { version = "0.2", path = "../halo2_proofs", default-features = false }
lazy_static = "1"
halo2curves = { version = "0.1.0" }
proptest = { version = "1.0.0", optional = true }
Expand All @@ -35,7 +35,7 @@ subtle = "2.3"
uint = "0.9.2" # MSRV 1.56.1

# Developer tooling dependencies
plotters = { version = "0.3.0", optional = true }
plotters = { version = "0.3.0", default-features = false, optional = true }

[dev-dependencies]
criterion = "0.3"
Expand All @@ -48,7 +48,12 @@ pprof = { version = "0.8", features = ["criterion", "flamegraph"] } # MSRV 1.56
bench = false

[features]
dev-graph = ["halo2_proofs/dev-graph", "plotters"]
test-dev-graph = [
"halo2_proofs/dev-graph",
"plotters",
"plotters/bitmap_backend",
"plotters/bitmap_encoder",
]
circuit-params = ["halo2_proofs/circuit-params"]
test-dependencies = ["proptest"]
unstable = []
Expand Down
2 changes: 1 addition & 1 deletion halo2_gadgets/src/ecc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,7 @@ pub(crate) mod tests {
assert_eq!(prover.verify(), Ok(()))
}

#[cfg(feature = "dev-graph")]
#[cfg(feature = "test-dev-graph")]
#[test]
fn print_ecc_chip() {
use plotters::prelude::*;
Expand Down
2 changes: 1 addition & 1 deletion halo2_gadgets/src/poseidon/pow5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ mod tests {
}
}

#[cfg(feature = "dev-graph")]
#[cfg(feature = "test-dev-graph")]
#[test]
fn print_poseidon_chip() {
use plotters::prelude::*;
Expand Down
2 changes: 1 addition & 1 deletion halo2_gadgets/src/sha256/table16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ trait Table16Assignment {
}

#[cfg(test)]
#[cfg(feature = "dev-graph")]
#[cfg(feature = "test-dev-graph")]
mod tests {
use super::super::{Sha256, BLOCK_SIZE};
use super::{message_schedule::msg_schedule_test_input, Table16Chip, Table16Config};
Expand Down
2 changes: 1 addition & 1 deletion halo2_gadgets/src/sinsemilla.rs
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ pub(crate) mod tests {
assert_eq!(prover.verify(), Ok(()))
}

#[cfg(feature = "dev-graph")]
#[cfg(feature = "test-dev-graph")]
#[test]
fn print_sinsemilla_chip() {
use plotters::prelude::*;
Expand Down
2 changes: 1 addition & 1 deletion halo2_gadgets/src/sinsemilla/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ pub mod tests {
assert_eq!(prover.verify(), Ok(()))
}

#[cfg(feature = "dev-graph")]
#[cfg(feature = "test-dev-graph")]
#[test]
fn print_merkle_chip() {
use plotters::prelude::*;
Expand Down
16 changes: 10 additions & 6 deletions halo2_proofs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ harness = false

[dependencies]
backtrace = { version = "0.3", optional = true }
rayon = "1.5.1"
ff = "0.13"
group = "0.13"
halo2curves = { version = "0.1.0" }
Expand All @@ -58,9 +57,10 @@ tracing = "0.1"
blake2b_simd = "1"
sha3 = "0.9.1"
rand_chacha = "0.3"
maybe-rayon = { version = "0.1.0", default-features = false }

# Developer tooling dependencies
plotters = { version = "0.3.0", optional = true }
plotters = { version = "0.3.0", default-features = false, optional = true }
tabbycat = { version = "0.1", features = ["attributes"], optional = true }

[dev-dependencies]
Expand All @@ -74,8 +74,14 @@ rand_core = { version = "0.6", default-features = false, features = ["getrandom"
getrandom = { version = "0.2", features = ["js"] }

[features]
default = ["batch"]
default = ["batch", "multicore"]
multicore = ["maybe-rayon/threads"]
dev-graph = ["plotters", "tabbycat"]
test-dev-graph = [
"dev-graph",
"plotters/bitmap_backend",
"plotters/bitmap_encoder",
]
gadget-traces = ["backtrace"]
thread-safe-region = []
sanity-checks = []
Expand All @@ -87,6 +93,4 @@ bench = false

[[example]]
name = "circuit-layout"
required-features = ["dev-graph"]


required-features = ["test-dev-graph"]
2 changes: 1 addition & 1 deletion halo2_proofs/benches/commit_zk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use group::ff::Field;
use halo2_proofs::*;
use halo2curves::pasta::pallas::Scalar;
use maybe_rayon::{current_num_threads, prelude::*};
use rand_chacha::rand_core::RngCore;
use rand_chacha::ChaCha20Rng;
use rand_core::SeedableRng;
use rayon::{current_num_threads, prelude::*};

fn rand_poly_serial(mut rng: ChaCha20Rng, domain: usize) -> Vec<Scalar> {
// Sample a random polynomial of degree n - 1
Expand Down
2 changes: 1 addition & 1 deletion halo2_proofs/examples/vector-mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ impl<F: Field> NumericInstructions<F> for FieldChip<F> {

#[cfg(feature = "thread-safe-region")]
{
use rayon::prelude::{
use maybe_rayon::prelude::{
IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator,
};
layouter.assign_region(
Expand Down
2 changes: 1 addition & 1 deletion halo2_proofs/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ pub fn recursive_butterfly_arithmetic<Scalar: Field, G: FftGroup<Scalar>>(
a[1] -= &t;
} else {
let (left, right) = a.split_at_mut(n / 2);
rayon::join(
multicore::join(
|| recursive_butterfly_arithmetic(left, n / 2, twiddle_chunk * 2, twiddles),
|| recursive_butterfly_arithmetic(right, n / 2, twiddle_chunk * 2, twiddles),
);
Expand Down
10 changes: 4 additions & 6 deletions halo2_proofs/src/dev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ use ff::FromUniformBytes;
use group::Group;

use crate::circuit::layouter::SyncDeps;
use crate::multicore::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
ParallelSliceMut,
};
use crate::plonk::permutation::keygen::Assembly;
use crate::{
circuit,
Expand All @@ -26,12 +30,6 @@ use crate::{
},
poly::Rotation,
};
use rayon::{
iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
},
slice::ParallelSliceMut,
};

pub mod metadata;
use metadata::Column as ColumnMetadata;
Expand Down
77 changes: 73 additions & 4 deletions halo2_proofs/src/multicore.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,74 @@
//! An interface for dealing with the kinds of parallel computations involved in
//! `halo2`. It's currently just a (very!) thin wrapper around [`rayon`] but may
//! be extended in the future to allow for various parallelism strategies.
#[cfg(all(
feature = "multicore",
target_arch = "wasm32",
not(target_feature = "atomics")
))]
compile_error!(
"The multicore feature flag is not supported on wasm32 architectures without atomics"
);

pub use rayon::{current_num_threads, scope, Scope};
pub use maybe_rayon::{
iter::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
IntoParallelRefMutIterator, ParallelIterator,
},
join, scope,
slice::ParallelSliceMut,
Scope,
};

#[cfg(feature = "multicore")]
pub use maybe_rayon::current_num_threads;

#[cfg(not(feature = "multicore"))]
pub fn current_num_threads() -> usize {
1
}

#[cfg(not(feature = "multicore"))]
pub trait IndexedParallelIterator: std::iter::Iterator {}

pub trait TryFoldAndReduce<T, E> {
/// Implements `iter.try_fold().try_reduce()` for `rayon::iter::ParallelIterator`,
/// falling back on `Iterator::try_fold` when the `multicore` feature flag is
/// disabled.
/// The `try_fold_and_reduce` function can only be called by a iter with
/// `Result<T, E>` item type because the `fold_op` must meet the trait
/// bounds of both `try_fold` and `try_reduce` from rayon.
fn try_fold_and_reduce(
self,
identity: impl Fn() -> T + Send + Sync,
fold_op: impl Fn(T, Result<T, E>) -> Result<T, E> + Send + Sync,
) -> Result<T, E>;
}

#[cfg(feature = "multicore")]
impl<T, E, I> TryFoldAndReduce<T, E> for I
where
T: Send + Sync,
E: Send + Sync,
I: maybe_rayon::iter::ParallelIterator<Item = Result<T, E>>,
{
fn try_fold_and_reduce(
self,
identity: impl Fn() -> T + Send + Sync,
fold_op: impl Fn(T, Result<T, E>) -> Result<T, E> + Send + Sync,
) -> Result<T, E> {
self.try_fold(&identity, &fold_op)
.try_reduce(&identity, |a, b| fold_op(a, Ok(b)))
}
}

#[cfg(not(feature = "multicore"))]
impl<T, E, I> TryFoldAndReduce<T, E> for I
where
I: std::iter::Iterator<Item = Result<T, E>>,
{
fn try_fold_and_reduce(
mut self,
identity: impl Fn() -> T + Send + Sync,
fold_op: impl Fn(T, Result<T, E>) -> Result<T, E> + Send + Sync,
) -> Result<T, E> {
self.try_fold(identity(), fold_op)
}
}
22 changes: 18 additions & 4 deletions halo2_proofs/src/plonk/permutation/keygen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};

use ff::{Field, PrimeField};
use group::Curve;
use rayon::prelude::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
IntoParallelRefMutIterator, ParallelIterator, ParallelSliceMut,
};

use super::{Argument, ProvingKey, VerifyingKey};
use crate::{
arithmetic::{parallelize, CurveAffine},
multicore::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator,
IntoParallelRefMutIterator, ParallelIterator,
},
plonk::{Any, Column, Error},
poly::{
commitment::{Blind, CommitmentScheme, Params},
Expand Down Expand Up @@ -133,12 +133,19 @@ impl Assembly {
&self.columns
}

#[cfg(feature = "multicore")]
/// Returns mappings of the copies.
pub fn mapping(
&self,
) -> impl Iterator<Item = impl IndexedParallelIterator<Item = (usize, usize)> + '_> {
self.mapping.iter().map(|c| c.par_iter().copied())
}

#[cfg(not(feature = "multicore"))]
/// Returns mappings of the copies.
pub fn mapping(&self) -> impl Iterator<Item = impl Iterator<Item = (usize, usize)> + '_> {
self.mapping.iter().map(|c| c.iter().copied())
}
}

#[cfg(feature = "thread-safe-region")]
Expand Down Expand Up @@ -304,6 +311,7 @@ impl Assembly {
&self.columns
}

#[cfg(feature = "multicore")]
/// Returns mappings of the copies.
pub fn mapping(
&self,
Expand All @@ -314,6 +322,12 @@ impl Assembly {
.map(move |j| self.mapping_at_idx(i, j))
})
}

#[cfg(not(feature = "multicore"))]
/// Returns mappings of the copies.
pub fn mapping(&self) -> impl Iterator<Item = impl Iterator<Item = (usize, usize)> + '_> {
(0..self.num_cols).map(move |i| (0..self.col_len).map(move |j| self.mapping_at_idx(i, j)))
}
}

pub(crate) fn build_pk<'params, C: CurveAffine, P: Params<'params, C>>(
Expand Down
5 changes: 4 additions & 1 deletion halo2_proofs/src/plonk/vanishing/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ use ff::{Field, PrimeField};
use group::Curve;
use rand_chacha::ChaCha20Rng;
use rand_core::{RngCore, SeedableRng};
use rayon::{current_num_threads, prelude::*};

use super::Argument;
use crate::{
arithmetic::{eval_polynomial, CurveAffine},
multicore::{
current_num_threads, IndexedParallelIterator, IntoParallelIterator,
IntoParallelRefMutIterator, ParallelIterator, ParallelSliceMut,
},
plonk::{ChallengeX, ChallengeY, Error},
poly::{
self,
Expand Down
12 changes: 6 additions & 6 deletions halo2_proofs/src/plonk/verifier/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ use ff::FromUniformBytes;
use group::ff::Field;
use halo2curves::CurveAffine;
use rand_core::{OsRng, RngCore};
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};

use super::{verify_proof, VerificationStrategy};
use crate::{
multicore,
multicore::{
IndexedParallelIterator, IntoParallelIterator, ParallelIterator, TryFoldAndReduce,
},
plonk::{Error, VerifyingKey},
poly::{
commitment::{Params, MSM},
Expand Down Expand Up @@ -123,11 +124,10 @@ where
e
})
})
.try_fold(
.try_fold_and_reduce(
|| params.empty_msm(),
|msm, res| res.map(|proof_msm| accumulate_msm(msm, proof_msm)),
)
.try_reduce(|| params.empty_msm(), |a, b| Ok(accumulate_msm(a, b)));
|acc, res| res.map(|proof_msm| accumulate_msm(acc, proof_msm)),
);

match final_msm {
Ok(msm) => msm.check(),
Expand Down
2 changes: 1 addition & 1 deletion halo2_proofs/src/poly/kzg/multiopen/shplonk.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
mod prover;
mod verifier;

use crate::multicore::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
use crate::{
arithmetic::{eval_polynomial, lagrange_interpolate, CurveAffine},
poly::{query::Query, Coeff, Polynomial},
transcript::ChallengeScalar,
};
use ff::Field;
pub use prover::ProverSHPLONK;
use rayon::prelude::*;
use std::{
collections::{btree_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet},
marker::PhantomData,
Expand Down
Loading

0 comments on commit 1a905a9

Please sign in to comment.