diff --git a/Cargo.toml b/Cargo.toml index de037c4c..803beade 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,6 @@ num-bigint = "0.4.3" num-integer = "0.1.45" num-traits = "0.2.15" rand = "0.8" -rand_chacha = "0.3.1" halo2_curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", tag = "0.3.0", package = "halo2curves" } # system_halo2 @@ -25,6 +24,7 @@ halo2_wrong_ecc = { git = "https://github.com/privacy-scaling-explorations/halo2 poseidon = { git = "https://github.com/privacy-scaling-explorations/poseidon", tag = "v2022_10_22", optional = true } [dev-dependencies] +rand_chacha = "0.3.1" paste = "1.0.7" # system_halo2 diff --git a/examples/evm-verifier-with-accumulator.rs b/examples/evm-verifier-with-accumulator.rs index 69def21e..66fa8ddf 100644 --- a/examples/evm-verifier-with-accumulator.rs +++ b/examples/evm-verifier-with-accumulator.rs @@ -280,12 +280,12 @@ mod aggregation { let accumulators = snarks .iter() .flat_map(|snark| { + let protocol = snark.protocol.loaded(loader); let instances = assign_instances(&snark.instances); let mut transcript = PoseidonTranscript::, _>::new(loader, snark.proof()); - let proof = - Plonk::read_proof(svk, &snark.protocol, &instances, &mut transcript).unwrap(); - Plonk::succinct_verify(svk, &snark.protocol, &instances, &proof).unwrap() + let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript).unwrap(); + Plonk::succinct_verify(svk, &protocol, &instances, &proof).unwrap() }) .collect_vec(); @@ -555,11 +555,12 @@ fn gen_aggregation_evm_verifier( vk, Config::kzg() .with_num_instance(num_instance.clone()) - .with_accumulator_indices(accumulator_indices), + .with_accumulator_indices(Some(accumulator_indices)), ); let loader = EvmLoader::new::(); - let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); + let protocol = protocol.loaded(&loader); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); let instances = transcript.load_instances(num_instance); let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); diff --git a/examples/evm-verifier.rs b/examples/evm-verifier.rs index b51a9a30..9ed70b36 100644 --- a/examples/evm-verifier.rs +++ b/examples/evm-verifier.rs @@ -216,7 +216,8 @@ fn gen_evm_verifier( ); let loader = EvmLoader::new::(); - let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); + let protocol = protocol.loaded(&loader); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); let instances = transcript.load_instances(num_instance); let proof = Plonk::read_proof(&svk, &protocol, &instances, &mut transcript).unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 4e8da5fb..9749924d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,10 +20,14 @@ pub enum Error { } #[derive(Clone, Debug)] -pub struct Protocol { +pub struct Protocol +where + C: util::arithmetic::CurveAffine, + L: loader::Loader, +{ // Common description pub domain: util::arithmetic::Domain, - pub preprocessed: Vec, + pub preprocessed: Vec, pub num_instance: Vec, pub num_witness: Vec, pub num_challenge: Vec, @@ -31,7 +35,7 @@ pub struct Protocol { pub queries: Vec, pub quotient: util::protocol::QuotientPolynomial, // Minor customization - pub transcript_initial_state: Option, + pub transcript_initial_state: Option, pub instance_committing_key: Option>, pub linearization: Option, pub accumulator_indices: Vec>, diff --git a/src/loader.rs b/src/loader.rs index 8c39bae0..5a040f9f 100644 --- a/src/loader.rs +++ b/src/loader.rs @@ -19,15 +19,6 @@ pub trait LoadedEcPoint: Clone + Debug + PartialEq { type Loader: Loader; fn loader(&self) -> &Self::Loader; - - fn multi_scalar_multiplication( - pairs: impl IntoIterator< - Item = ( - >::LoadedScalar, - Self, - ), - >, - ) -> Self; } pub trait LoadedScalar: Clone + Debug + PartialEq + FieldOps { @@ -43,15 +34,6 @@ pub trait LoadedScalar: Clone + Debug + PartialEq + FieldOps { FieldOps::invert(self) } - fn batch_invert<'a>(values: impl IntoIterator) - where - Self: 'a, - { - values - .into_iter() - .for_each(|value| *value = LoadedScalar::invert(value).unwrap_or_else(|| value.clone())) - } - fn pow_const(&self, mut exp: u64) -> Self { assert!(exp > 0); @@ -102,6 +84,12 @@ pub trait EcPointLoader { lhs: &Self::LoadedEcPoint, rhs: &Self::LoadedEcPoint, ) -> Result<(), Error>; + + fn multi_scalar_multiplication( + pairs: &[(Self::LoadedScalar, Self::LoadedEcPoint)], + ) -> Self::LoadedEcPoint + where + Self: ScalarLoader; } pub trait ScalarLoader { @@ -226,6 +214,15 @@ pub trait ScalarLoader { .iter() .fold(self.load_one(), |acc, value| acc * *value) } + + fn batch_invert<'a>(values: impl IntoIterator) + where + Self::LoadedScalar: 'a, + { + values + .into_iter() + .for_each(|value| *value = LoadedScalar::invert(value).unwrap_or_else(|| value.clone())) + } } pub trait Loader: diff --git a/src/loader/evm/loader.rs b/src/loader/evm/loader.rs index 06ab7fd8..7d1dbb94 100644 --- a/src/loader/evm/loader.rs +++ b/src/loader/evm/loader.rs @@ -596,17 +596,6 @@ where fn loader(&self) -> &Rc { &self.loader } - - fn multi_scalar_multiplication(pairs: impl IntoIterator) -> Self { - pairs - .into_iter() - .map(|(scalar, ec_point)| match scalar.value { - Value::Constant(constant) if constant == U256::one() => ec_point, - _ => ec_point.loader.ec_point_scalar_mul(&ec_point, &scalar), - }) - .reduce(|acc, ec_point| acc.loader.ec_point_add(&acc, &ec_point)) - .unwrap() - } } #[derive(Clone)] @@ -759,73 +748,12 @@ impl> LoadedScalar for Scalar { fn loader(&self) -> &Rc { &self.loader } - - fn batch_invert<'a>(values: impl IntoIterator) { - let values = values.into_iter().collect_vec(); - let loader = &values.first().unwrap().loader; - let products = iter::once(values[0].clone()) - .chain( - iter::repeat_with(|| loader.allocate(0x20)) - .map(|ptr| loader.scalar(Value::Memory(ptr))) - .take(values.len() - 1), - ) - .collect_vec(); - - loader.code.borrow_mut().push(loader.scalar_modulus); - for _ in 2..values.len() { - loader.code.borrow_mut().dup(0); - } - - loader.push(products.first().unwrap()); - for (idx, (value, product)) in values.iter().zip(products.iter()).skip(1).enumerate() { - loader.push(value); - loader.code.borrow_mut().mulmod(); - if idx < values.len() - 2 { - loader.code.borrow_mut().dup(0); - } - loader.code.borrow_mut().push(product.ptr()).mstore(); - } - - let inv = loader.invert(products.last().unwrap()); - - loader.code.borrow_mut().push(loader.scalar_modulus); - for _ in 2..values.len() { - loader.code.borrow_mut().dup(0); - } - - loader.push(&inv); - for (value, product) in values.iter().rev().zip( - products - .iter() - .rev() - .skip(1) - .map(Some) - .chain(iter::once(None)), - ) { - if let Some(product) = product { - loader.push(value); - loader - .code - .borrow_mut() - .dup(2) - .dup(2) - .push(product.ptr()) - .mload() - .mulmod() - .push(value.ptr()) - .mstore() - .mulmod(); - } else { - loader.code.borrow_mut().push(value.ptr()).mstore(); - } - } - } } impl EcPointLoader for Rc where C: CurveAffine, - C::Scalar: PrimeField, + C::ScalarExt: PrimeField, { type LoadedEcPoint = EcPoint; @@ -839,6 +767,19 @@ where fn ec_point_assert_eq(&self, _: &str, _: &EcPoint, _: &EcPoint) -> Result<(), Error> { unimplemented!() } + + fn multi_scalar_multiplication( + pairs: &[(>::LoadedScalar, EcPoint)], + ) -> EcPoint { + pairs + .iter() + .map(|(scalar, ec_point)| match scalar.value { + Value::Constant(constant) if U256::one() == constant => ec_point.clone(), + _ => ec_point.loader.ec_point_scalar_mul(ec_point, scalar), + }) + .reduce(|acc, ec_point| acc.loader.ec_point_add(&acc, &ec_point)) + .unwrap() + } } impl> ScalarLoader for Rc { @@ -977,6 +918,67 @@ impl> ScalarLoader for Rc { self.scalar(Value::Memory(ptr)) } + + fn batch_invert<'a>(values: impl IntoIterator) { + let values = values.into_iter().collect_vec(); + let loader = &values.first().unwrap().loader; + let products = iter::once(values[0].clone()) + .chain( + iter::repeat_with(|| loader.allocate(0x20)) + .map(|ptr| loader.scalar(Value::Memory(ptr))) + .take(values.len() - 1), + ) + .collect_vec(); + + loader.code.borrow_mut().push(loader.scalar_modulus); + for _ in 2..values.len() { + loader.code.borrow_mut().dup(0); + } + + loader.push(products.first().unwrap()); + for (idx, (value, product)) in values.iter().zip(products.iter()).skip(1).enumerate() { + loader.push(value); + loader.code.borrow_mut().mulmod(); + if idx < values.len() - 2 { + loader.code.borrow_mut().dup(0); + } + loader.code.borrow_mut().push(product.ptr()).mstore(); + } + + let inv = loader.invert(products.last().unwrap()); + + loader.code.borrow_mut().push(loader.scalar_modulus); + for _ in 2..values.len() { + loader.code.borrow_mut().dup(0); + } + + loader.push(&inv); + for (value, product) in values.iter().rev().zip( + products + .iter() + .rev() + .skip(1) + .map(Some) + .chain(iter::once(None)), + ) { + if let Some(product) = product { + loader.push(value); + loader + .code + .borrow_mut() + .dup(2) + .dup(2) + .push(product.ptr()) + .mload() + .mulmod() + .push(value.ptr()) + .mstore() + .mulmod(); + } else { + loader.code.borrow_mut().push(value.ptr()).mstore(); + } + } + } } impl Loader for Rc diff --git a/src/loader/halo2.rs b/src/loader/halo2.rs index 9eae74e3..39e95c23 100644 --- a/src/loader/halo2.rs +++ b/src/loader/halo2.rs @@ -1,3 +1,7 @@ +use crate::{util::arithmetic::CurveAffine, Protocol}; +use halo2_proofs::circuit; +use std::rc::Rc; + pub(crate) mod loader; mod shim; @@ -27,3 +31,41 @@ mod util { impl>> Valuetools for I {} } + +impl Protocol +where + C: CurveAffine, +{ + pub fn loaded_preprocessed_as_witness<'a, EccChip: EccInstructions<'a, C>>( + &self, + loader: &Rc>, + ) -> Protocol>> { + let preprocessed = self + .preprocessed + .iter() + .map(|preprocessed| loader.assign_ec_point(circuit::Value::known(*preprocessed))) + .collect(); + let transcript_initial_state = + self.transcript_initial_state + .as_ref() + .map(|transcript_initial_state| { + loader.assign_scalar(circuit::Value::known( + loader.scalar_chip().integer(*transcript_initial_state), + )) + }); + Protocol { + domain: self.domain.clone(), + preprocessed, + num_instance: self.num_instance.clone(), + num_witness: self.num_witness.clone(), + num_challenge: self.num_challenge.clone(), + evaluations: self.evaluations.clone(), + queries: self.queries.clone(), + quotient: self.quotient.clone(), + transcript_initial_state, + instance_committing_key: self.instance_committing_key.clone(), + linearization: self.linearization.clone(), + accumulator_indices: self.accumulator_indices.clone(), + } + } +} diff --git a/src/loader/halo2/loader.rs b/src/loader/halo2/loader.rs index 0d288ce7..d446c716 100644 --- a/src/loader/halo2/loader.rs +++ b/src/loader/halo2/loader.rs @@ -11,9 +11,7 @@ use crate::{ use halo2_proofs::circuit; use std::{ cell::{Ref, RefCell, RefMut}, - collections::btree_map::{BTreeMap, Entry}, fmt::{self, Debug}, - iter, marker::PhantomData, ops::{Add, AddAssign, Deref, Mul, MulAssign, Neg, Sub, SubAssign}, rc::Rc, @@ -25,7 +23,6 @@ pub struct Halo2Loader<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { ctx: RefCell, num_scalar: RefCell, num_ec_point: RefCell, - const_ec_point: RefCell>>, _marker: PhantomData, #[cfg(test)] row_meterings: RefCell>, @@ -38,7 +35,6 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc ctx: RefCell::new(ctx), num_scalar: RefCell::default(), num_ec_point: RefCell::default(), - const_ec_point: RefCell::default(), #[cfg(test)] row_meterings: RefCell::default(), _marker: PhantomData, @@ -61,16 +57,16 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc self.ctx.borrow() } - pub(crate) fn ctx_mut(&self) -> RefMut<'_, EccChip::Context> { + pub fn ctx_mut(&self) -> RefMut<'_, EccChip::Context> { self.ctx.borrow_mut() } - pub fn assign_const_scalar(self: &Rc, constant: C::Scalar) -> Scalar<'a, C, EccChip> { + fn assign_const_scalar(self: &Rc, constant: C::Scalar) -> Scalar<'a, C, EccChip> { let assigned = self .scalar_chip() .assign_constant(&mut self.ctx_mut(), constant) .unwrap(); - self.scalar(Value::Assigned(assigned)) + self.scalar_from_assigned(assigned) } pub fn assign_scalar( @@ -81,10 +77,17 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc .scalar_chip() .assign_integer(&mut self.ctx_mut(), scalar) .unwrap(); + self.scalar_from_assigned(assigned) + } + + pub fn scalar_from_assigned( + self: &Rc, + assigned: EccChip::AssignedScalar, + ) -> Scalar<'a, C, EccChip> { self.scalar(Value::Assigned(assigned)) } - pub(crate) fn scalar( + fn scalar( self: &Rc, value: Value, ) -> Scalar<'a, C, EccChip> { @@ -97,23 +100,12 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc } } - pub fn assign_const_ec_point(self: &Rc, constant: C) -> EcPoint<'a, C, EccChip> { - let coordinates = constant.coordinates().unwrap(); - match self - .const_ec_point - .borrow_mut() - .entry((*coordinates.x(), *coordinates.y())) - { - Entry::Occupied(entry) => entry.get().clone(), - Entry::Vacant(entry) => { - let assigned = self - .ecc_chip() - .assign_point(&mut self.ctx_mut(), circuit::Value::known(constant)) - .unwrap(); - let ec_point = self.ec_point(assigned); - entry.insert(ec_point).clone() - } - } + fn assign_const_ec_point(self: &Rc, constant: C) -> EcPoint<'a, C, EccChip> { + self.ec_point_from_assigned( + self.ecc_chip() + .assign_constant(&mut self.ctx_mut(), constant) + .unwrap(), + ) } pub fn assign_ec_point( @@ -124,16 +116,26 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Halo2Loader<'a, C, Ecc .ecc_chip() .assign_point(&mut self.ctx_mut(), ec_point) .unwrap(); - self.ec_point(assigned) + self.ec_point_from_assigned(assigned) + } + + pub fn ec_point_from_assigned( + self: &Rc, + assigned: EccChip::AssignedEcPoint, + ) -> EcPoint<'a, C, EccChip> { + self.ec_point(Value::Assigned(assigned)) } - fn ec_point(self: &Rc, assigned: EccChip::AssignedEcPoint) -> EcPoint<'a, C, EccChip> { + fn ec_point( + self: &Rc, + value: Value, + ) -> EcPoint<'a, C, EccChip> { let index = *self.num_ec_point.borrow(); *self.num_ec_point.borrow_mut() += 1; EcPoint { loader: self.clone(), index, - assigned, + value, } } @@ -305,7 +307,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Scalar<'a, C, EccChip> &self.loader } - pub(crate) fn assigned(&self) -> EccChip::AssignedScalar { + pub fn assigned(&self) -> EccChip::AssignedScalar { match &self.value { Value::Constant(constant) => self.loader.assign_const_scalar(*constant).assigned(), Value::Assigned(assigned) => assigned.clone(), @@ -451,12 +453,15 @@ impl<'a, 'b, C: CurveAffine, EccChip: EccInstructions<'a, C>> MulAssign<&'b Self pub struct EcPoint<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> { loader: Rc>, index: usize, - assigned: EccChip::AssignedEcPoint, + value: Value, } impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPoint<'a, C, EccChip> { pub fn assigned(&self) -> EccChip::AssignedEcPoint { - self.assigned.clone() + match &self.value { + Value::Constant(constant) => self.loader.assign_const_ec_point(*constant).assigned(), + Value::Assigned(assigned) => assigned.clone(), + } } } @@ -474,65 +479,13 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> LoadedEcPoint fn loader(&self) -> &Self::Loader { &self.loader } - - fn multi_scalar_multiplication( - pairs: impl IntoIterator, Self)>, - ) -> Self { - let pairs = pairs.into_iter().collect_vec(); - let loader = &pairs[0].0.loader; - - let (non_scaled, scaled) = pairs.iter().fold( - (Vec::new(), Vec::new()), - |(mut non_scaled, mut scaled), (scalar, ec_point)| { - if matches!(scalar.value, Value::Constant(constant) if constant == C::Scalar::one()) - { - non_scaled.push(ec_point.assigned()); - } else { - scaled.push((ec_point.assigned(), scalar.assigned())) - } - (non_scaled, scaled) - }, - ); - - let output = iter::empty() - .chain(if scaled.is_empty() { - None - } else { - Some( - loader - .ecc_chip - .borrow_mut() - .multi_scalar_multiplication(&mut loader.ctx_mut(), scaled) - .unwrap(), - ) - }) - .chain(non_scaled) - .reduce(|acc, ec_point| { - EccInstructions::add( - loader.ecc_chip().deref(), - &mut loader.ctx_mut(), - &acc, - &ec_point, - ) - .unwrap() - }) - .map(|output| { - loader - .ecc_chip() - .normalize(&mut loader.ctx_mut(), &output) - .unwrap() - }) - .unwrap(); - - loader.ec_point(output) - } } impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> Debug for EcPoint<'a, C, EccChip> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("EcPoint") .field("index", &self.index) - .field("assigned", &self.assigned) + .field("value", &self.value) .finish() } } @@ -596,7 +549,7 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPointLoader type LoadedEcPoint = EcPoint<'a, C, EccChip>; fn ec_point_load_const(&self, ec_point: &C) -> EcPoint<'a, C, EccChip> { - self.assign_const_ec_point(*ec_point) + self.ec_point(Value::Constant(*ec_point)) } fn ec_point_assert_eq( @@ -605,9 +558,100 @@ impl<'a, C: CurveAffine, EccChip: EccInstructions<'a, C>> EcPointLoader lhs: &EcPoint<'a, C, EccChip>, rhs: &EcPoint<'a, C, EccChip>, ) -> Result<(), crate::Error> { - self.ecc_chip() - .assert_equal(&mut self.ctx_mut(), &lhs.assigned(), &rhs.assigned()) - .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())) + match (&lhs.value, &rhs.value) { + (Value::Constant(lhs), Value::Constant(rhs)) => { + assert_eq!(lhs, rhs); + Ok(()) + } + (Value::Constant(constant), Value::Assigned(assigned)) + | (Value::Assigned(assigned), Value::Constant(constant)) => { + let constant = self.assign_const_ec_point(*constant).assigned(); + self.ecc_chip() + .assert_equal(&mut self.ctx_mut(), assigned, &constant) + .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())) + } + (Value::Assigned(lhs), Value::Assigned(rhs)) => self + .ecc_chip() + .assert_equal(&mut self.ctx_mut(), lhs, rhs) + .map_err(|_| crate::Error::AssertionFailure(annotation.to_string())), + } + } + + fn multi_scalar_multiplication( + pairs: &[( + >::LoadedScalar, + EcPoint<'a, C, EccChip>, + )], + ) -> EcPoint<'a, C, EccChip> { + let loader = &pairs[0].0.loader; + + let (constant, fixed_base, variable_base_non_scaled, variable_base_scaled) = + pairs.iter().fold( + (C::identity(), Vec::new(), Vec::new(), Vec::new()), + |( + mut constant, + mut fixed_base, + mut variable_base_non_scaled, + mut variable_base_scaled, + ), + (scalar, ec_point)| { + match (&ec_point.value, &scalar.value) { + (Value::Constant(ec_point), Value::Constant(scalar)) => { + constant = (*ec_point * scalar + constant).into() + } + (Value::Constant(ec_point), Value::Assigned(scalar)) => { + fixed_base.push((scalar.clone(), *ec_point)) + } + (Value::Assigned(ec_point), Value::Constant(scalar)) + if scalar.eq(&C::Scalar::one()) => + { + variable_base_non_scaled.push(ec_point.clone()); + } + (Value::Assigned(ec_point), _) => { + variable_base_scaled.push((scalar.assigned(), ec_point.clone())) + } + }; + ( + constant, + fixed_base, + variable_base_non_scaled, + variable_base_scaled, + ) + }, + ); + + let fixed_base_msm = (!fixed_base.is_empty()).then(|| { + loader + .ecc_chip + .borrow_mut() + .fixed_base_msm(&mut loader.ctx_mut(), &fixed_base) + .unwrap() + }); + let variable_base_msm = (!variable_base_scaled.is_empty()).then(|| { + loader + .ecc_chip + .borrow_mut() + .variable_base_msm(&mut loader.ctx_mut(), &variable_base_scaled) + .unwrap() + }); + let output = loader + .ecc_chip() + .sum_with_const( + &mut loader.ctx_mut(), + &variable_base_non_scaled + .into_iter() + .chain(fixed_base_msm) + .chain(variable_base_msm) + .collect_vec(), + constant, + ) + .unwrap(); + let normalized = loader + .ecc_chip() + .normalize(&mut loader.ctx_mut(), &output) + .unwrap(); + + loader.ec_point_from_assigned(normalized) } } diff --git a/src/loader/halo2/shim.rs b/src/loader/halo2/shim.rs index 67f06cc9..97921d24 100644 --- a/src/loader/halo2/shim.rs +++ b/src/loader/halo2/shim.rs @@ -98,17 +98,23 @@ pub trait EccInstructions<'a, C: CurveAffine>: Clone + Debug { point: Value, ) -> Result; - fn add( + fn sum_with_const( &self, ctx: &mut Self::Context, - p0: &Self::AssignedEcPoint, - p1: &Self::AssignedEcPoint, + values: &[Self::AssignedEcPoint], + constant: C, ) -> Result; - fn multi_scalar_multiplication( + fn fixed_base_msm( &mut self, ctx: &mut Self::Context, - pairs: Vec<(Self::AssignedEcPoint, Self::AssignedScalar)>, + pairs: &[(Self::AssignedScalar, C)], + ) -> Result; + + fn variable_base_msm( + &mut self, + ctx: &mut Self::Context, + pairs: &[(Self::AssignedScalar, Self::AssignedEcPoint)], ) -> Result; fn normalize( @@ -146,6 +152,7 @@ mod halo2_wrong { AssignedPoint, BaseFieldEccChip, }; use rand::rngs::OsRng; + use std::iter; impl<'a, F: FieldExt> Context for RegionCtx<'a, F> { fn constrain_equal(&mut self, lhs: Cell, rhs: Cell) -> Result<(), Error> { @@ -348,21 +355,51 @@ mod halo2_wrong { self.assign_point(ctx, point) } - fn add( + fn sum_with_const( &self, ctx: &mut Self::Context, - p0: &Self::AssignedEcPoint, - p1: &Self::AssignedEcPoint, + values: &[Self::AssignedEcPoint], + constant: C, + ) -> Result { + if values.is_empty() { + return self.assign_constant(ctx, constant); + } + + iter::empty() + .chain( + (!bool::from(constant.is_identity())) + .then(|| self.assign_constant(ctx, constant)), + ) + .chain(values.iter().cloned().map(Ok)) + .reduce(|acc, ec_point| self.add(ctx, &acc?, &ec_point?)) + .unwrap() + } + + fn fixed_base_msm( + &mut self, + ctx: &mut Self::Context, + pairs: &[(Self::AssignedScalar, C)], ) -> Result { - self.add(ctx, p0, p1) + // FIXME: Implement fixed base MSM in halo2_wrong + let pairs = pairs + .iter() + .map(|(scalar, base)| { + Ok::<_, Error>((scalar.clone(), self.assign_constant(ctx, *base)?)) + }) + .collect::, _>>()?; + self.variable_base_msm(ctx, &pairs) } - fn multi_scalar_multiplication( + fn variable_base_msm( &mut self, ctx: &mut Self::Context, - pairs: Vec<(Self::AssignedEcPoint, Self::AssignedScalar)>, + pairs: &[(Self::AssignedScalar, Self::AssignedEcPoint)], ) -> Result { const WINDOW_SIZE: usize = 3; + let pairs = pairs + .iter() + .map(|(scalar, base)| (base.clone(), scalar.clone())) + .collect_vec(); match self.mul_batch_1d_horizontal(ctx, pairs.clone(), WINDOW_SIZE) { Err(_) => { if self.assign_aux(ctx, WINDOW_SIZE, pairs.len()).is_err() { diff --git a/src/loader/native.rs b/src/loader/native.rs index 6bf9f7c4..1451ff76 100644 --- a/src/loader/native.rs +++ b/src/loader/native.rs @@ -19,15 +19,6 @@ impl LoadedEcPoint for C { fn loader(&self) -> &NativeLoader { &LOADER } - - fn multi_scalar_multiplication(pairs: impl IntoIterator) -> Self { - pairs - .into_iter() - .map(|(scalar, base)| base * scalar) - .reduce(|acc, value| acc + value) - .unwrap() - .to_affine() - } } impl FieldOps for F { @@ -61,6 +52,17 @@ impl EcPointLoader for NativeLoader { .then_some(()) .ok_or_else(|| Error::AssertionFailure(annotation.to_string())) } + + fn multi_scalar_multiplication( + pairs: &[(>::LoadedScalar, C)], + ) -> C { + pairs + .iter() + .map(|(scalar, base)| *base * scalar) + .reduce(|acc, value| acc + value) + .unwrap() + .to_affine() + } } impl ScalarLoader for NativeLoader { diff --git a/src/pcs/kzg/multiopen/bdfg21.rs b/src/pcs/kzg/multiopen/bdfg21.rs index 287700d7..7f70ca61 100644 --- a/src/pcs/kzg/multiopen/bdfg21.rs +++ b/src/pcs/kzg/multiopen/bdfg21.rs @@ -203,8 +203,8 @@ fn query_set_coeffs>( }) .collect_vec(); - T::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); - T::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); + T::Loader::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); + T::Loader::batch_invert(coeffs.iter_mut().flat_map(QuerySetCoeff::denoms)); coeffs.iter_mut().for_each(QuerySetCoeff::evaluate); coeffs diff --git a/src/system/halo2.rs b/src/system/halo2.rs index bf3e4091..3d7cf140 100644 --- a/src/system/halo2.rs +++ b/src/system/halo2.rs @@ -70,8 +70,11 @@ impl Config { self } - pub fn with_accumulator_indices(mut self, accumulator_indices: Vec<(usize, usize)>) -> Self { - self.accumulator_indices = Some(accumulator_indices); + pub fn with_accumulator_indices( + mut self, + accumulator_indices: Option>, + ) -> Self { + self.accumulator_indices = accumulator_indices; self } } @@ -202,7 +205,7 @@ impl<'a, F: FieldExt> Polynomials<'a, F> { degree - 1 }; - let num_phase = *cs.advice_column_phase().iter().max().unwrap() as usize + 1; + let num_phase = *cs.advice_column_phase().iter().max().unwrap_or(&0) as usize + 1; let remapping = |phase: Vec| { let num = phase.iter().fold(vec![0; num_phase], |mut num, phase| { num[*phase as usize] += 1; diff --git a/src/system/halo2/test/kzg.rs b/src/system/halo2/test/kzg.rs index 0b071175..3b7e1694 100644 --- a/src/system/halo2/test/kzg.rs +++ b/src/system/halo2/test/kzg.rs @@ -45,7 +45,7 @@ macro_rules! halo2_kzg_config { $crate::system::halo2::Config::kzg() .set_zk($zk) .with_num_proof($num_proof) - .with_accumulator_indices($accumulator_indices) + .with_accumulator_indices(Some($accumulator_indices)) }; } diff --git a/src/system/halo2/test/kzg/evm.rs b/src/system/halo2/test/kzg/evm.rs index 4ce850c1..4e57d369 100644 --- a/src/system/halo2/test/kzg/evm.rs +++ b/src/system/halo2/test/kzg/evm.rs @@ -37,16 +37,17 @@ macro_rules! halo2_kzg_evm_verify { let runtime_code = { let svk = $params.get_g()[0].into(); let dk = ($params.g2(), $params.s_g2()).into(); - let mut transcript = EvmTranscript::<_, Rc, _, _>::new(loader.clone()); + let protocol = $protocol.loaded(&loader); + let mut transcript = EvmTranscript::<_, Rc, _, _>::new(&loader); let instances = transcript.load_instances( $instances .iter() .map(|instances| instances.len()) .collect_vec(), ); - let proof = <$plonk_verifier>::read_proof(&svk, $protocol, &instances, &mut transcript) + let proof = <$plonk_verifier>::read_proof(&svk, &protocol, &instances, &mut transcript) .unwrap(); - <$plonk_verifier>::verify(&svk, &dk, $protocol, &instances, &proof).unwrap(); + <$plonk_verifier>::verify(&svk, &dk, &protocol, &instances, &proof).unwrap(); loader.runtime_code() }; diff --git a/src/system/halo2/test/kzg/halo2.rs b/src/system/halo2/test/kzg/halo2.rs index 1bd332a0..18767c98 100644 --- a/src/system/halo2/test/kzg/halo2.rs +++ b/src/system/halo2/test/kzg/halo2.rs @@ -89,12 +89,12 @@ pub fn accumulate<'a>( let mut accumulators = snarks .iter() .flat_map(|snark| { + let protocol = snark.protocol.loaded(loader); let instances = assign_instances(&snark.instances); let mut transcript = PoseidonTranscript::, _>::new(loader, snark.proof()); - let proof = - Plonk::read_proof(svk, &snark.protocol, &instances, &mut transcript).unwrap(); - Plonk::succinct_verify(svk, &snark.protocol, &instances, &proof).unwrap() + let proof = Plonk::read_proof(svk, &protocol, &instances, &mut transcript).unwrap(); + Plonk::succinct_verify(svk, &protocol, &instances, &proof).unwrap() }) .collect_vec(); diff --git a/src/system/halo2/transcript/evm.rs b/src/system/halo2/transcript/evm.rs index 77aca5cf..e461bc05 100644 --- a/src/system/halo2/transcript/evm.rs +++ b/src/system/halo2/transcript/evm.rs @@ -32,13 +32,13 @@ where C: CurveAffine, C::Scalar: PrimeField, { - pub fn new(loader: Rc) -> Self { + pub fn new(loader: &Rc) -> Self { let ptr = loader.allocate(0x20); assert_eq!(ptr, 0); let mut buf = MemoryChunk::new(ptr); buf.extend(0x20); Self { - loader, + loader: loader.clone(), stream: 0, buf, _marker: PhantomData, diff --git a/src/system/halo2/transcript/halo2.rs b/src/system/halo2/transcript/halo2.rs index 43701bd7..bb7e8e23 100644 --- a/src/system/halo2/transcript/halo2.rs +++ b/src/system/halo2/transcript/halo2.rs @@ -1,11 +1,11 @@ use crate::{ loader::{ - halo2::{self, EcPoint, EccInstructions, Halo2Loader, IntegerInstructions, Scalar}, + halo2::{EcPoint, EccInstructions, Halo2Loader, IntegerInstructions, Scalar}, native::{self, NativeLoader}, Loader, ScalarLoader, }, util::{ - arithmetic::{fe_from_big, fe_to_big, CurveAffine, FieldExt, PrimeField}, + arithmetic::{fe_to_fe, CurveAffine, FieldExt, PrimeField}, hash::Poseidon, transcript::{Transcript, TranscriptRead, TranscriptWrite}, Itertools, @@ -57,10 +57,11 @@ impl< > PoseidonTranscript>, Value, T, RATE, R_F, R_P> { pub fn new(loader: &Rc>, stream: Value) -> Self { + let buf = Poseidon::new(loader, R_F, R_P); Self { loader: loader.clone(), stream, - buf: Poseidon::new(loader.clone(), R_F, R_P), + buf, _marker: PhantomData, } } @@ -99,7 +100,7 @@ impl< .map(|encoded| { encoded .into_iter() - .map(|encoded| self.loader.scalar(halo2::loader::Value::Assigned(encoded))) + .map(|encoded| self.loader.scalar_from_assigned(encoded)) .collect_vec() }) .map_err(|_| Error::Transcript(io::ErrorKind::Other, "".to_string()))?; @@ -160,7 +161,7 @@ impl = Option::from(ec_point.coordinates().map(|coordinates| { [coordinates.x(), coordinates.y()] .into_iter() - .map(|fe| fe_from_big(fe_to_big(*fe))) + .cloned() + .map(fe_to_fe) .collect_vec() })) .ok_or_else(|| { diff --git a/src/util/hash/poseidon.rs b/src/util/hash/poseidon.rs index 878b69ce..c0fc03a6 100644 --- a/src/util/hash/poseidon.rs +++ b/src/util/hash/poseidon.rs @@ -113,7 +113,7 @@ pub struct Poseidon { } impl, const T: usize, const RATE: usize> Poseidon { - pub fn new(loader: L::Loader, r_f: usize, r_p: usize) -> Self { + pub fn new(loader: &L::Loader, r_f: usize, r_p: usize) -> Self { Self { spec: Spec::new(r_f, r_p), state: State::new( diff --git a/src/util/msm.rs b/src/util/msm.rs index a7a3d45d..a15eab9d 100644 --- a/src/util/msm.rs +++ b/src/util/msm.rs @@ -1,6 +1,6 @@ use crate::{ loader::{LoadedEcPoint, Loader}, - util::arithmetic::CurveAffine, + util::{arithmetic::CurveAffine, Itertools}, }; use std::{ default::Default, @@ -71,11 +71,11 @@ where .loader() .ec_point_load_const(&gen) }); - L::LoadedEcPoint::multi_scalar_multiplication( - iter::empty() - .chain(self.constant.map(|constant| (constant, gen.unwrap()))) - .chain(self.scalars.into_iter().zip(self.bases.into_iter())), - ) + let pairs = iter::empty() + .chain(self.constant.map(|constant| (constant, gen.unwrap()))) + .chain(self.scalars.into_iter().zip(self.bases.into_iter())) + .collect_vec(); + L::multi_scalar_multiplication(&pairs) } pub fn scale(&mut self, factor: &L::LoadedScalar) { diff --git a/src/util/protocol.rs b/src/util/protocol.rs index bae363ff..e9ecd9f0 100644 --- a/src/util/protocol.rs +++ b/src/util/protocol.rs @@ -4,6 +4,7 @@ use crate::{ arithmetic::{CurveAffine, Domain, Field, Fraction, Rotation}, Itertools, }, + Protocol, }; use num_integer::Integer; use num_traits::One; @@ -15,6 +16,37 @@ use std::{ ops::{Add, Mul, Neg, Sub}, }; +impl Protocol +where + C: CurveAffine, +{ + pub fn loaded>(&self, loader: &L) -> Protocol { + let preprocessed = self + .preprocessed + .iter() + .map(|preprocessed| loader.ec_point_load_const(preprocessed)) + .collect(); + let transcript_initial_state = self + .transcript_initial_state + .as_ref() + .map(|transcript_initial_state| loader.load_const(transcript_initial_state)); + Protocol { + domain: self.domain.clone(), + preprocessed, + num_instance: self.num_instance.clone(), + num_witness: self.num_witness.clone(), + num_challenge: self.num_challenge.clone(), + evaluations: self.evaluations.clone(), + queries: self.queries.clone(), + quotient: self.quotient.clone(), + transcript_initial_state, + instance_committing_key: self.instance_committing_key.clone(), + linearization: self.linearization.clone(), + accumulator_indices: self.accumulator_indices.clone(), + } + } +} + #[derive(Clone, Copy, Debug)] pub enum CommonPolynomial { Identity, diff --git a/src/verifier.rs b/src/verifier.rs index 07529603..0eef23d2 100644 --- a/src/verifier.rs +++ b/src/verifier.rs @@ -20,7 +20,7 @@ where fn read_proof( svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], transcript: &mut T, ) -> Result @@ -29,7 +29,7 @@ where fn succinct_verify( svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], proof: &Self::Proof, ) -> Result, Error>; @@ -37,7 +37,7 @@ where fn verify( svk: &MOS::SuccinctVerifyingKey, dk: &MOS::DecidingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], proof: &Self::Proof, ) -> Result diff --git a/src/verifier/plonk.rs b/src/verifier/plonk.rs index 9e08e585..c3276af4 100644 --- a/src/verifier/plonk.rs +++ b/src/verifier/plonk.rs @@ -29,7 +29,7 @@ where fn read_proof( svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], transcript: &mut T, ) -> Result @@ -41,7 +41,7 @@ where fn succinct_verify( svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], proof: &Self::Proof, ) -> Result, Error> { @@ -52,7 +52,7 @@ where &proof.z, ); - L::LoadedScalar::batch_invert(common_poly_eval.denoms()); + L::batch_invert(common_poly_eval.denoms()); common_poly_eval.evaluate(); common_poly_eval @@ -96,9 +96,9 @@ where L: Loader, MOS: MultiOpenScheme, { - fn read( + pub fn read( svk: &MOS::SuccinctVerifyingKey, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], transcript: &mut T, ) -> Result @@ -106,9 +106,8 @@ where T: TranscriptRead, AE: AccumulatorEncoding, { - let loader = transcript.loader(); if let Some(transcript_initial_state) = &protocol.transcript_initial_state { - transcript.common_scalar(&loader.load_const(transcript_initial_state))?; + transcript.common_scalar(transcript_initial_state)?; } if protocol.num_instance @@ -211,7 +210,7 @@ where }) } - fn empty_queries(protocol: &Protocol) -> Vec> { + pub fn empty_queries(protocol: &Protocol) -> Vec> { protocol .queries .iter() @@ -227,7 +226,7 @@ where fn queries( &self, - protocol: &Protocol, + protocol: &Protocol, mut evaluations: HashMap, ) -> Vec> { Self::empty_queries(protocol) @@ -244,7 +243,7 @@ where fn commitments( &self, - protocol: &Protocol, + protocol: &Protocol, common_poly_eval: &CommonPolynomialEvaluation, evaluations: &mut HashMap, ) -> Result>, Error> { @@ -254,7 +253,7 @@ where protocol .preprocessed .iter() - .map(|value| Msm::base(loader.ec_point_load_const(value))), + .map(|value| Msm::base(value.clone())), ) .chain( self.committed_instances @@ -357,7 +356,7 @@ where fn evaluations( &self, - protocol: &Protocol, + protocol: &Protocol, instances: &[Vec], common_poly_eval: &CommonPolynomialEvaluation, ) -> Result, Error> { @@ -424,9 +423,13 @@ where } } -fn langranges(protocol: &Protocol, instances: &[Vec]) -> impl IntoIterator +fn langranges( + protocol: &Protocol, + instances: &[Vec], +) -> impl IntoIterator where C: CurveAffine, + L: Loader, { let instance_eval_lagrange = protocol.instance_committing_key.is_none().then(|| { let queries = {