Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions halo2-base/src/gates/circuit/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,18 @@ impl<F: ScalarField> BaseCircuitBuilder<F> {
self.core
.phase_manager
.iter()
.map(|pm| pm.break_points.get().expect("break points not set").clone())
.map(|pm| pm.break_points.borrow().as_ref().expect("break points not set").clone())
.collect()
}

/// Sets the break points of the circuit.
pub fn set_break_points(&mut self, break_points: MultiPhaseThreadBreakPoints) {
if break_points.is_empty() {
return;
}
self.core.touch(break_points.len() - 1);
for (pm, bp) in self.core.phase_manager.iter().zip_eq(break_points) {
pm.break_points.set(bp).unwrap();
*pm.break_points.borrow_mut() = Some(bp);
}
}

Expand All @@ -207,6 +211,15 @@ impl<F: ScalarField> BaseCircuitBuilder<F> {
self
}

/// Clears state and copies, effectively resetting the circuit builder.
pub fn clear(&mut self) {
self.core.clear();
for lm in &mut self.lookup_manager {
lm.cells_to_lookup.lock().unwrap().clear();
lm.copy_manager.lock().unwrap().clear();
}
}

/// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists.
/// * `phase`: The challenge phase (as an index) of the gate thread.
pub fn main(&mut self, phase: usize) -> &mut Context<F> {
Expand Down
9 changes: 9 additions & 0 deletions halo2-base/src/gates/circuit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ impl<F: ScalarField> BaseConfig<F> {
MaybeRangeConfig::WithRange(config) => config.gate.max_rows = usable_rows,
}
}

/// Initialization of config at very beginning of `synthesize`.
/// Loads fixed lookup table, if using.
pub fn initialize(&self, layouter: &mut impl Layouter<F>) {
// only load lookup table if we are actually doing lookups
if let MaybeRangeConfig::WithRange(config) = &self.base {
config.load_lookup_table(layouter).expect("load lookup table should not fail");
}
}
}

impl<F: ScalarField> Circuit<F> for BaseCircuitBuilder<F> {
Expand Down
10 changes: 9 additions & 1 deletion halo2-base/src/gates/flex_gate/threads/multi_phase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ impl<F: ScalarField> MultiPhaseCoreManager<F> {
self
}

/// Clears all threads in all phases and copy manager.
pub fn clear(&mut self) {
for pm in &mut self.phase_manager {
pm.clear();
}
self.copy_manager.lock().unwrap().clear();
}

/// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists.
/// * `phase`: The challenge phase (as an index) of the gate thread.
pub fn main(&mut self, phase: usize) -> &mut Context<F> {
Expand All @@ -88,7 +96,7 @@ impl<F: ScalarField> MultiPhaseCoreManager<F> {
}

/// Populate `self` up to Phase `phase` (inclusive)
fn touch(&mut self, phase: usize) {
pub(crate) fn touch(&mut self, phase: usize) {
while self.phase_manager.len() <= phase {
let _phase = self.phase_manager.len();
let pm = SinglePhaseCoreManager::new(self.witness_gen_only, self.copy_manager.clone())
Expand Down
26 changes: 16 additions & 10 deletions halo2-base/src/gates/flex_gate/threads/single_phase.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{any::TypeId, cell::OnceCell};
use std::{any::TypeId, cell::RefCell};

use getset::CopyGetters;

Expand Down Expand Up @@ -39,7 +39,7 @@ pub struct SinglePhaseCoreManager<F: ScalarField> {
pub(crate) phase: usize,
/// A very simple computation graph for the basic vertical gate. Must be provided as a "pinning"
/// when running the production prover.
pub break_points: OnceCell<ThreadBreakPoints>,
pub break_points: RefCell<Option<ThreadBreakPoints>>,
}

impl<F: ScalarField> SinglePhaseCoreManager<F> {
Expand Down Expand Up @@ -93,6 +93,12 @@ impl<F: ScalarField> SinglePhaseCoreManager<F> {
self
}

/// Clears all threads and copy manager
pub fn clear(&mut self) {
self.threads = vec![];
self.copy_manager.lock().unwrap().clear();
}

/// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists.
pub fn main(&mut self) -> &mut Context<F> {
if self.threads.is_empty() {
Expand Down Expand Up @@ -147,7 +153,8 @@ impl<F: ScalarField> VirtualRegionManager<F> for SinglePhaseCoreManager<F> {

fn assign_raw(&self, (config, usable_rows): &Self::Config, region: &mut Region<F>) {
if self.witness_gen_only {
let break_points = self.break_points.get().expect("break points not set");
let binding = self.break_points.borrow();
let break_points = binding.as_ref().expect("break points not set");
assign_witnesses(&self.threads, config, region, break_points);
} else {
let mut copy_manager = self.copy_manager.lock().unwrap();
Expand All @@ -159,13 +166,12 @@ impl<F: ScalarField> VirtualRegionManager<F> for SinglePhaseCoreManager<F> {
*usable_rows,
self.use_unknown,
);
self.break_points.set(break_points).unwrap_or_else(|break_points| {
assert_eq!(
self.break_points.get().unwrap(),
&break_points,
"previously set break points don't match"
);
});
let mut bp = self.break_points.borrow_mut();
if let Some(bp) = bp.as_ref() {
assert_eq!(bp, &break_points, "break points don't match");
} else {
*bp = Some(break_points);
}
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions halo2-base/src/virtual_region/copy_constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ impl<F: Field + Ord> CopyConstraintManager<F> {
self.assigned_advices.insert(context_cell, cell);
context_cell
}

/// Clears state
pub fn clear(&mut self) {
*self = Self::default();
}
}

impl<F: Field + Ord> Drop for CopyConstraintManager<F> {
Expand Down
4 changes: 2 additions & 2 deletions halo2-base/src/virtual_region/tests/lookups/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,13 @@ fn test_ram_prover() {
let vk = keygen_vk(&params, &circuit).unwrap();
let pk = keygen_pk(&params, vk, &circuit).unwrap();
let circuit_params = circuit.params();
let break_points = circuit.cpu.break_points.get().unwrap().clone();
let break_points = circuit.cpu.break_points.borrow().clone().unwrap();
drop(circuit);

let memory: Vec<_> = (0..mem_len).map(|_| Fr::random(&mut rng)).collect();
let ptrs = [(); CYCLES].map(|_| rng.gen_range(0..memory.len()));
let mut circuit = RAMCircuit::new(memory, ptrs, circuit_params, true);
circuit.cpu.break_points.set(break_points).unwrap();
*circuit.cpu.break_points.borrow_mut() = Some(break_points);
circuit.compute();

let proof = gen_proof(&params, &pk, circuit);
Expand Down