Skip to content

Commit

Permalink
refactor: Refactor supernova module to avoid needless pass-by-value
Browse files Browse the repository at this point in the history
- Modified `SuperNovaAugmentedCircuitParams` to be passed by reference in `test_recursive_circuit_with` function and associated instances in `test.rs`.
- Updated `SuperNovaAugmentedCircuit::new` method to match changes in the `test_recursive_circuit_with` function.
- Revamped arguments handling in `synthesize_non_base_case` and `synthesize` functions in `circuit.rs`, changing from owned types to references to reduce unnecessary cloning.
- Altered function body to accommodate the changes for referenced values in `circuit.rs`.
  • Loading branch information
huitseeker committed Aug 27, 2023
1 parent d04534b commit 012c58c
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 35 deletions.
42 changes: 21 additions & 21 deletions src/supernova/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,16 +315,16 @@ impl<'a, G: Group, SC: StepCircuit<G::Base>> SuperNovaAugmentedCircuit<'a, G, SC
fn synthesize_non_base_case<CS: ConstraintSystem<<G as Group>::Base>>(
&self,
mut cs: CS,
params: AllocatedNum<G::Base>,
i: AllocatedNum<G::Base>,
z_0: Vec<AllocatedNum<G::Base>>,
z_i: Vec<AllocatedNum<G::Base>>,
params: &AllocatedNum<G::Base>,
i: &AllocatedNum<G::Base>,
z_0: &[AllocatedNum<G::Base>],
z_i: &[AllocatedNum<G::Base>],
U: &[AllocatedRelaxedR1CSInstance<G>],
u: AllocatedR1CSInstance<G>,
T: AllocatedPoint<G>,
u: &AllocatedR1CSInstance<G>,
T: &AllocatedPoint<G>,
arity: usize,
last_augmented_circuit_index: &AllocatedNum<G::Base>,
program_counter: Option<AllocatedNum<G::Base>>,
program_counter: &Option<AllocatedNum<G::Base>>,
num_augmented_circuits: usize,
) -> Result<
(
Expand All @@ -344,8 +344,8 @@ impl<'a, G: Group, SC: StepCircuit<G::Base>> SuperNovaAugmentedCircuit<'a, G, SC
+ 2 * arity // zo, z1
+ num_augmented_circuits * (7 + 2 * self.params.n_limbs), // #num_augmented_circuits * (7 + [X0, X1]*#num_limb)
);
ro.absorb(&params);
ro.absorb(&i);
ro.absorb(params);
ro.absorb(i);

if self.params.is_primary_circuit {
if let Some(program_counter) = program_counter.as_ref() {
Expand All @@ -355,10 +355,10 @@ impl<'a, G: Group, SC: StepCircuit<G::Base>> SuperNovaAugmentedCircuit<'a, G, SC
}
}

for e in &z_0 {
for e in z_0 {
ro.absorb(e);
}
for e in &z_i {
for e in z_i {
ro.absorb(e);
}

Expand All @@ -382,9 +382,9 @@ impl<'a, G: Group, SC: StepCircuit<G::Base>> SuperNovaAugmentedCircuit<'a, G, SC
)?;
let U_fold = U_to_fold.fold_with_r1cs(
cs.namespace(|| "compute fold of U and u"),
&params,
&u,
&T,
params,
u,
T,
self.ro_consts.clone(),
self.params.limb_width,
self.params.n_limbs,
Expand Down Expand Up @@ -488,16 +488,16 @@ impl<'a, G: Group, SC: StepCircuit<G::Base>> SuperNovaAugmentedCircuit<'a, G, SC
let (last_augmented_circuit_index_checked, U_next_non_base, check_non_base_pass) = self
.synthesize_non_base_case(
cs.namespace(|| "synthesize non base case"),
params.clone(),
i.clone(),
z_0.clone(),
z_i.clone(),
&params,
&i,
&z_0,
&z_i,
&U,
u.clone(),
T,
&u,
&T,
arity,
&last_augmented_circuit_index,
program_counter.clone(),
&program_counter,
num_augmented_circuits,
)?;

Expand Down
28 changes: 14 additions & 14 deletions src/supernova/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ where
}

fn print_constraints_name_on_error_index<G1, G2, Ca, Cb>(
err: SuperNovaError,
err: &SuperNovaError,
running_claim: &RunningClaim<G1, G2, Ca, Cb>,
num_augmented_circuits: usize,
) where
Expand All @@ -241,7 +241,7 @@ fn print_constraints_name_on_error_index<G1, G2, Ca, Cb>(
Cb: StepCircuit<G2::Scalar>,
{
match err {
SuperNovaError::UnSatIndex(msg, index) if msg == "r_primary" => {
SuperNovaError::UnSatIndex(msg, index) if *msg == "r_primary" => {
let circuit_primary: SuperNovaAugmentedCircuit<'_, G2, Ca> = SuperNovaAugmentedCircuit::new(
&running_claim.params.augmented_circuit_params_primary,
None,
Expand All @@ -252,10 +252,10 @@ fn print_constraints_name_on_error_index<G1, G2, Ca, Cb>(
let mut cs: TestShapeCS<G1> = TestShapeCS::new();
let _ = circuit_primary.synthesize(&mut cs);
cs.constraints
.get(index)
.get(*index)
.tap_some(|constraint| debug!("{msg} failed at constraint {}", constraint.3));
}
SuperNovaError::UnSatIndex(msg, index) if msg == "r_secondary" || msg == "l_secondary" => {
SuperNovaError::UnSatIndex(msg, index) if *msg == "r_secondary" || *msg == "l_secondary" => {
let circuit_secondary: SuperNovaAugmentedCircuit<'_, G1, Cb> = SuperNovaAugmentedCircuit::new(
&running_claim.params.augmented_circuit_params_secondary,
None,
Expand All @@ -266,7 +266,7 @@ fn print_constraints_name_on_error_index<G1, G2, Ca, Cb>(
let mut cs: TestShapeCS<G2> = TestShapeCS::new();
let _ = circuit_secondary.synthesize(&mut cs);
cs.constraints
.get(index)
.get(*index)
.tap_some(|constraint| debug!("{msg} failed at constraint {}", constraint.3));
}
_ => (),
Expand Down Expand Up @@ -435,7 +435,7 @@ where
.verify(&running_claim1, &z0_primary, &z0_secondary)
.map_err(|err| {
print_constraints_name_on_error_index(
err,
&err,
&running_claim1,
test_rom.num_augmented_circuits(),
)
Expand All @@ -449,7 +449,7 @@ where
.verify(&running_claim2, &z0_primary, &z0_secondary)
.map_err(|err| {
print_constraints_name_on_error_index(
err,
&err,
&running_claim2,
test_rom.num_augmented_circuits(),
)
Expand Down Expand Up @@ -485,8 +485,8 @@ fn test_trivial_nivc() {

// In the following we use 1 to refer to the primary, and 2 to refer to the secondary circuit
fn test_recursive_circuit_with<G1, G2>(
primary_params: SuperNovaAugmentedCircuitParams,
secondary_params: SuperNovaAugmentedCircuitParams,
primary_params: &SuperNovaAugmentedCircuitParams,
secondary_params: &SuperNovaAugmentedCircuitParams,
ro_consts1: ROConstantsCircuit<G2>,
ro_consts2: ROConstantsCircuit<G1>,
num_constraints_primary: usize,
Expand All @@ -499,7 +499,7 @@ fn test_recursive_circuit_with<G1, G2>(
let step_circuit1 = TrivialTestCircuit::default();
let arity1 = step_circuit1.arity();
let circuit1: SuperNovaAugmentedCircuit<'_, G2, TrivialTestCircuit<<G2 as Group>::Base>> =
SuperNovaAugmentedCircuit::new(&primary_params, None, &step_circuit1, ro_consts1.clone(), 2);
SuperNovaAugmentedCircuit::new(primary_params, None, &step_circuit1, ro_consts1.clone(), 2);
let mut cs: ShapeCS<G1> = ShapeCS::new();
if let Err(e) = circuit1.synthesize(&mut cs) {
panic!("{}", e)
Expand All @@ -512,7 +512,7 @@ fn test_recursive_circuit_with<G1, G2>(
let arity2 = step_circuit2.arity();
let circuit2: SuperNovaAugmentedCircuit<'_, G1, TrivialSecondaryCircuit<<G1 as Group>::Base>> =
SuperNovaAugmentedCircuit::new(
&secondary_params,
secondary_params,
None,
&step_circuit2,
ro_consts2.clone(),
Expand Down Expand Up @@ -542,7 +542,7 @@ fn test_recursive_circuit_with<G1, G2>(
);
let step_circuit = TrivialTestCircuit::default();
let circuit1: SuperNovaAugmentedCircuit<'_, G2, TrivialTestCircuit<<G2 as Group>::Base>> =
SuperNovaAugmentedCircuit::new(&primary_params, Some(inputs1), &step_circuit, ro_consts1, 2);
SuperNovaAugmentedCircuit::new(primary_params, Some(inputs1), &step_circuit, ro_consts1, 2);
if let Err(e) = circuit1.synthesize(&mut cs1) {
panic!("{}", e)
}
Expand All @@ -568,7 +568,7 @@ fn test_recursive_circuit_with<G1, G2>(
let step_circuit = TrivialSecondaryCircuit::default();
let circuit2: SuperNovaAugmentedCircuit<'_, G1, TrivialSecondaryCircuit<<G1 as Group>::Base>> =
SuperNovaAugmentedCircuit::new(
&secondary_params,
secondary_params,
Some(inputs2),
&step_circuit,
ro_consts2,
Expand All @@ -591,5 +591,5 @@ fn test_recursive_circuit() {
let ro_consts1: ROConstantsCircuit<G2> = PoseidonConstantsCircuit::default();
let ro_consts2: ROConstantsCircuit<G1> = PoseidonConstantsCircuit::default();

test_recursive_circuit_with::<G1, G2>(params1, params2, ro_consts1, ro_consts2, 9835, 12035);
test_recursive_circuit_with::<G1, G2>(&params1, &params2, ro_consts1, ro_consts2, 9835, 12035);
}

0 comments on commit 012c58c

Please sign in to comment.