diff --git a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs index 6603cede..a8f039c2 100644 --- a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs +++ b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs @@ -19,6 +19,7 @@ use halo2_base::{ halo2_proofs::halo2curves::bn256::G1, utils::fs::gen_srs, }; +use itertools::Itertools; use rand_core::OsRng; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] @@ -68,6 +69,7 @@ fn fixed_base_msm_test( fn random_fixed_base_msm_circuit( params: MSMCircuitParams, + bases: Vec, // bases are fixed in vkey so don't randomly generate stage: CircuitBuilderStage, break_points: Option, ) -> RangeCircuitBuilder { @@ -78,8 +80,7 @@ fn random_fixed_base_msm_circuit( CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), }; - let (bases, scalars): (Vec<_>, Vec<_>) = - (0..params.batch_size).map(|_| (G1Affine::random(OsRng), Fr::random(OsRng))).unzip(); + let scalars = (0..params.batch_size).map(|_| Fr::random(OsRng)).collect_vec(); let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); fixed_base_msm_test(&mut builder, params, bases, scalars); @@ -106,7 +107,8 @@ fn test_fixed_base_msm() { ) .unwrap(); - let circuit = random_fixed_base_msm_circuit(params, CircuitBuilderStage::Mock, None); + let bases = (0..params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec(); + let circuit = random_fixed_base_msm_circuit(params, bases, CircuitBuilderStage::Mock, None); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -132,8 +134,13 @@ fn bench_fixed_base_msm() -> Result<(), Box> { let params = gen_srs(k); println!("{bench_params:?}"); - let circuit = - random_fixed_base_msm_circuit(bench_params, CircuitBuilderStage::Keygen, None); + let bases = (0..bench_params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec(); + let circuit = random_fixed_base_msm_circuit( + bench_params, + bases.clone(), + CircuitBuilderStage::Keygen, + None, 
+ ); let vk_time = start_timer!(|| "Generating vkey"); let vk = keygen_vk(&params, &circuit)?; @@ -149,6 +156,7 @@ fn bench_fixed_base_msm() -> Result<(), Box> { let proof_time = start_timer!(|| "Proving time"); let circuit = random_fixed_base_msm_circuit( bench_params, + bases, CircuitBuilderStage::Prover, Some(break_points), ); diff --git a/halo2-ecc/src/ecc/ecdsa.rs b/halo2-ecc/src/ecc/ecdsa.rs index 0f37a71a..d7406a17 100644 --- a/halo2-ecc/src/ecc/ecdsa.rs +++ b/halo2-ecc/src/ecc/ecdsa.rs @@ -49,6 +49,7 @@ where u1.limbs().to_vec(), base_chip.limb_bits, fixed_window_bits, + true, // we can call it with scalar_is_safe = true because of the u1_small check below ); let u2_mul = scalar_multiply( base_chip, diff --git a/halo2-ecc/src/ecc/fixed_base.rs b/halo2-ecc/src/ecc/fixed_base.rs index de9e8d86..dc67b8d6 100644 --- a/halo2-ecc/src/ecc/fixed_base.rs +++ b/halo2-ecc/src/ecc/fixed_base.rs @@ -1,5 +1,6 @@ #![allow(non_snake_case)] use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip}; +use crate::ecc::ec_sub_strict; use crate::fields::{FieldChip, PrimeField, Selectable}; use group::Curve; use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder}; @@ -8,14 +9,17 @@ use itertools::Itertools; use rayon::prelude::*; use std::cmp::min; -// computes `[scalar] * P` on y^2 = x^3 + b where `P` is fixed (constant) -// - `scalar` is represented as a non-empty reference array of `AssignedValue`s -// - `scalar = sum_i scalar_i * 2^{max_bits * i}` -// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` -// assumes: -// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits) -// - `max_bits <= modulus::<F>.bits()` - +/// Computes `[scalar] * P` on y^2 = x^3 + b where `P` is fixed (constant) +/// - `scalar` is represented as a non-empty reference array of `AssignedValue`s +/// - `scalar = sum_i scalar_i * 2^{max_bits * i}` +/// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar 
field `F` +/// +/// # Assumptions +/// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits) +/// - `scalar > 0` +/// - If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`) +/// - Even if `scalar_is_safe == false`, some constraints may still fail if `scalar` is not in range [1, order of `P`) +/// - `max_bits <= modulus::<F>.bits()` pub fn scalar_multiply( chip: &FC, ctx: &mut Context, @@ -23,6 +27,7 @@ pub fn scalar_multiply( scalar: Vec>, max_bits: usize, window_bits: usize, + scalar_is_safe: bool, ) -> EcPoint where F: PrimeField, @@ -33,8 +38,8 @@ where let zero = chip.load_constant(ctx, C::Base::zero()); return EcPoint::new(zero.clone(), zero); } - debug_assert!(!scalar.is_empty()); - debug_assert!((max_bits as u32) <= F::NUM_BITS); + assert!(!scalar.is_empty()); + assert!((max_bits as u32) <= F::NUM_BITS); let total_bits = max_bits * scalar.len(); let num_windows = (total_bits + window_bits - 1) / window_bits; @@ -91,7 +96,7 @@ where let is_zero_window = chip.gate().is_zero(ctx, bit_sum); let add_point = ec_select_from_bits(chip, ctx, cached_point_window, bit_window); curr_point = if let Some(curr_point) = curr_point { - let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, false); + let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, !scalar_is_safe); let zero_sum = ec_select(chip, ctx, curr_point, sum, is_zero_window); Some(ec_select(chip, ctx, zero_sum, add_point, is_started)) } else { @@ -107,117 +112,16 @@ where curr_point.unwrap() } -/* To reduce total amount of code, just always use msm_par below. 
// basically just adding up individual fixed_base::scalar_multiply except that we do all batched normalization of cached points at once to further save inversion time during witness generation // we also use the random accumulator for some extra efficiency (which also works in scalar multiply case but that is TODO) -pub fn msm( - chip: &EccChip, - ctx: &mut Context, - points: &[C], - scalars: Vec>>, - max_scalar_bits_per_cell: usize, - window_bits: usize, -) -> EcPoint -where - F: PrimeField, - C: CurveAffineExt, - FC: FieldChip + Selectable, -{ - assert!((max_scalar_bits_per_cell as u32) <= F::NUM_BITS); - let scalar_len = scalars[0].len(); - let total_bits = max_scalar_bits_per_cell * scalar_len; - let num_windows = (total_bits + window_bits - 1) / window_bits; - - // `cached_points` is a flattened 2d vector - // first we compute all cached points in Jacobian coordinates since it's fastest - let cached_points_jacobian = points - .iter() - .flat_map(|point| { - let base_pt = point.to_curve(); - // cached_points[idx][i * 2^w + j] holds `[j * 2^(i * w)] * points[idx]` for j in {0, ..., 2^w - 1} - let mut increment = base_pt; - (0..num_windows) - .flat_map(|i| { - let mut curr = increment; - let cache_vec = std::iter::once(increment) - .chain((1..(1usize << min(window_bits, total_bits - i * window_bits))).map( - |_| { - let prev = curr; - curr += increment; - prev - }, - )) - .collect_vec(); - increment = curr; - cache_vec - }) - .collect_vec() - }) - .collect_vec(); - // for use in circuits we need affine coordinates, so we do a batch normalize: this is much more efficient than calling `to_affine` one by one since field inversion is very expensive - // initialize to all 0s - let mut cached_points_affine = vec![C::default(); cached_points_jacobian.len()]; - C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine); - - let field_chip = chip.field_chip(); - let cached_points = cached_points_affine - .into_iter() - .map(|point| 
chip.assign_constant_point(ctx, point)) - .collect_vec(); - - let bits = scalars - .into_iter() - .flat_map(|scalar| { - assert_eq!(scalar.len(), scalar_len); - scalar - .into_iter() - .flat_map(|scalar_chunk| { - field_chip.gate().num_to_bits(ctx, scalar_chunk, max_scalar_bits_per_cell) - }) - .collect_vec() - }) - .collect_vec(); - - let scalar_mults = cached_points - .chunks(cached_points.len() / points.len()) - .zip(bits.chunks(total_bits)) - .map(|(cached_points, bits)| { - let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev(); - let bit_window_rev = bits.chunks(window_bits).rev(); - let mut curr_point = None; - // `is_started` is just a way to deal with if `curr_point` is actually identity - let mut is_started = ctx.load_zero(); - for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) { - let is_zero_window = { - let sum = field_chip.gate().sum(ctx, bit_window.iter().copied()); - field_chip.gate().is_zero(ctx, sum) - }; - let add_point = - ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window); - curr_point = if let Some(curr_point) = curr_point { - let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, false); - let zero_sum = ec_select(field_chip, ctx, curr_point, sum, is_zero_window); - Some(ec_select(field_chip, ctx, zero_sum, add_point, is_started)) - } else { - Some(add_point) - }; - is_started = { - // is_started || !is_zero_window - // (a || !b) = (1-b) + a*b - let not_zero_window = field_chip.gate().not(ctx, is_zero_window); - field_chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window) - }; - } - curr_point.unwrap() - }) - .collect_vec(); - chip.sum::(ctx, scalar_mults) -} -*/ /// # Assumptions /// * `points.len() = scalars.len()` /// * `scalars[i].len() = scalars[j].len()` for all `i,j` +/// * `points` are all on the curve +/// * `points[i]` is not point at infinity (0, 0); these should be filtered out beforehand +/// * The integer value of 
`scalars[i]` is less than the order of `points[i]` (some constraints may fail otherwise) +/// * Output may be the point at infinity, in which case (0, 0) is returned pub fn msm_par( chip: &EccChip, builder: &mut GateThreadBuilder, @@ -232,6 +136,9 @@ where C: CurveAffineExt, FC: FieldChip + Selectable, { + if points.is_empty() { + return chip.assign_constant_point(builder.main(phase), C::identity()); + } assert!((max_scalar_bits_per_cell as u32) <= F::NUM_BITS); assert_eq!(points.len(), scalars.len()); assert!(!points.is_empty(), "fixed_base::msm_par requires at least one point"); @@ -306,6 +213,7 @@ where let add_point = ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window); curr_point = if let Some(curr_point) = curr_point { + // We don't need strict mode because we assume scalars[i] is less than the order of points[i] let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, false); let zero_sum = ec_select(field_chip, ctx, curr_point, sum, is_zero_window); Some(ec_select(field_chip, ctx, zero_sum, add_point, is_started)) @@ -319,8 +227,16 @@ where field_chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window) }; } - curr_point.unwrap() + (curr_point.unwrap(), is_started) }, ); - chip.sum::(builder.main(phase), scalar_mults) + let ctx = builder.main(phase); + // sum `scalar_mults` but take into account possibility of identity points + let any_point = chip.load_random_point::(ctx); + let mut acc = any_point.clone(); + for (point, is_not_identity) in scalar_mults { + let new_acc = chip.add_unequal(ctx, &acc, point, true); + acc = chip.select(ctx, new_acc, acc, is_not_identity); + } + ec_sub_strict(field_chip, ctx, acc, any_point) } diff --git a/halo2-ecc/src/ecc/mod.rs b/halo2-ecc/src/ecc/mod.rs index eea89a31..a4dedd5f 100644 --- a/halo2-ecc/src/ecc/mod.rs +++ b/halo2-ecc/src/ecc/mod.rs @@ -469,11 +469,12 @@ where /// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` /// /// # Assumptions 
-/// * `P` is not the point at infinity -/// * `scalar > 0` -/// * If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`) -/// * `scalar_i < 2^{max_bits} for all i` -/// * `max_bits <= modulus::<F>.bits()`, and equality only allowed when the order of `P` equals the modulus of `F` +/// - `P` is not the point at infinity +/// - `scalar > 0` +/// - If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`) +/// - Even if `scalar_is_safe == false`, some constraints may still fail if `scalar` is not in range [1, order of `P`) +/// - `scalar_i < 2^{max_bits} for all i` +/// - `max_bits <= modulus::<F>.bits()`, and equality only allowed when the order of `P` equals the modulus of `F` pub fn scalar_multiply( chip: &FC, ctx: &mut Context, @@ -1094,6 +1095,7 @@ where } impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { + /// See [`fixed_base::scalar_multiply`] for more details. // TODO: put a check in place that scalar is < modulus of C::Scalar pub fn fixed_base_scalar_mult( &self, @@ -1102,6 +1104,7 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { scalar: Vec>, max_bits: usize, window_bits: usize, + scalar_is_safe: bool, ) -> EcPoint where C: CurveAffineExt, @@ -1114,6 +1117,7 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { scalar, max_bits, window_bits, + scalar_is_safe, ) }