diff --git a/halo2-ecc/src/bn254/pairing.rs b/halo2-ecc/src/bn254/pairing.rs index 08e2fe06..8fded38c 100644 --- a/halo2-ecc/src/bn254/pairing.rs +++ b/halo2-ecc/src/bn254/pairing.rs @@ -2,6 +2,7 @@ use super::{Fp12Chip, Fp2Chip, FpChip, FpPoint, FqPoint}; use crate::fields::PrimeField; use crate::halo2_proofs::{ + arithmetic::CurveAffine, circuit::Value, halo2curves::bn256::{self, G1Affine, G2Affine, SIX_U_PLUS_2_NAF}, halo2curves::bn256::{Fq, Fq2, FROBENIUS_COEFF_FQ12_C1}, @@ -215,13 +216,16 @@ pub fn fp12_multiply_with_line_equal( // - `0 <= loop_count < r` and `loop_count < p` (to avoid [loop_count]Q' = Frob_p(Q')) // - x^3 + b = 0 has no solution in Fp2, i.e., the y-coordinate of Q cannot be 0. -pub fn miller_loop_BN( +pub fn miller_loop_BN( ecc_chip: &EccChip>, ctx: &mut Context, Q: &EcPoint>, P: &EcPoint>, pseudo_binary_encoding: &[i8], -) -> FqPoint { +) -> FqPoint +where + C: CurveAffine, +{ let mut i = pseudo_binary_encoding.len() - 1; while pseudo_binary_encoding[i] == 0 { i -= 1; @@ -262,7 +266,7 @@ pub fn miller_loop_BN( let f_sq = fp12_chip.mul(ctx, &f, &f); f = fp12_multiply_with_line_equal::(ecc_chip.field_chip(), ctx, &f_sq, &R, P); } - R = ecc_chip.double(ctx, &R); + R = ecc_chip.double::(ctx, &R); assert!(pseudo_binary_encoding[i] <= 1 && pseudo_binary_encoding[i] >= -1); if pseudo_binary_encoding[i] != 0 { @@ -300,12 +304,15 @@ pub fn miller_loop_BN( // let pairs = [(a_i, b_i)], a_i in G_1, b_i in G_2 // output is Prod_i e'(a_i, b_i), where e'(a_i, b_i) is the output of `miller_loop_BN(b_i, a_i)` -pub fn multi_miller_loop_BN( +pub fn multi_miller_loop_BN( ecc_chip: &EccChip>, ctx: &mut Context, pairs: Vec<(&EcPoint>, &EcPoint>)>, pseudo_binary_encoding: &[i8], -) -> FqPoint { +) -> FqPoint +where + C: CurveAffine, +{ let mut i = pseudo_binary_encoding.len() - 1; while pseudo_binary_encoding[i] == 0 { i -= 1; @@ -354,7 +361,7 @@ pub fn multi_miller_loop_BN( } } for r in r.iter_mut() { - *r = ecc_chip.double(ctx, &r); + *r = ecc_chip.double::(ctx, &r); } assert!(pseudo_binary_encoding[i] <= 1 && pseudo_binary_encoding[i] >= -1); @@ -517,7 +524,7 @@ impl PairingChip { ) -> FqPoint { let fp2_chip = Fp2Chip::::construct(self.fp_chip.clone()); let g2_chip = EccChip::construct(fp2_chip); - miller_loop_BN::( + miller_loop_BN::( &g2_chip, ctx, Q, @@ -533,7 +540,7 @@ impl PairingChip { ) -> FqPoint { let fp2_chip = Fp2Chip::::construct(self.fp_chip.clone()); let g2_chip = EccChip::construct(fp2_chip); - multi_miller_loop_BN::( + multi_miller_loop_BN::( &g2_chip, ctx, pairs, diff --git a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs index c7239d9d..424acacf 100644 --- a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs +++ b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs @@ -1,3 +1,4 @@ +use ff::Field; use std::{env::var, fs::File}; #[allow(unused_imports)] diff --git a/halo2-ecc/src/bn254/tests/msm.rs b/halo2-ecc/src/bn254/tests/msm.rs index e2d3d716..4e68e6f7 100644 --- a/halo2-ecc/src/bn254/tests/msm.rs +++ b/halo2-ecc/src/bn254/tests/msm.rs @@ -1,6 +1,6 @@ -use std::{env::var, fs::File}; - +use ff::Field; use halo2_base::SKIP_FIRST_PASS; +use std::{env::var, fs::File}; use super::*; diff --git a/halo2-ecc/src/ecc/ecdsa.rs b/halo2-ecc/src/ecc/ecdsa.rs index 6f940874..e81ea664 100644 --- a/halo2-ecc/src/ecc/ecdsa.rs +++ b/halo2-ecc/src/ecc/ecdsa.rs @@ -56,7 +56,7 @@ where base_chip.limb_bits, fixed_window_bits, ); - let u2_mul = scalar_multiply::( + let u2_mul = scalar_multiply::( base_chip, ctx, pubkey, diff --git a/halo2-ecc/src/ecc/mod.rs b/halo2-ecc/src/ecc/mod.rs index 972f0f6c..3b1586b2 100644 --- a/halo2-ecc/src/ecc/mod.rs +++ b/halo2-ecc/src/ecc/mod.rs @@ -138,23 +138,31 @@ pub fn ec_sub_unequal>( // formula from https://crypto.stanford.edu/pbc/notes/elliptic/explicit.html // assume y != 0 (otherwise 2P = O) -// lamb = 3x^2 / (2 y) % p +// lamb = 3x^2 + a / (2 y) % p // x_3 = out[0] = lambda^2 - 2 x % p // y_3 = out[1] = lambda (x - x_3) - y % p -// we precompute lambda and constrain (2y) * lambda = 3 x^2 (mod p) +// we precompute lambda and constrain (2y) * lambda = 3 x^2 + a(mod p) // then we compute x_3 = lambda^2 - 2 x (mod p) // y_3 = lambda (x - x_3) - y (mod p) -pub fn ec_double>( +pub fn ec_double, C>( chip: &FC, ctx: &mut Context, P: &EcPoint, -) -> EcPoint { +) -> EcPoint +where + C: CurveAffine, +{ // removed optimization that computes `2 * lambda` while assigning witness to `lambda` simultaneously, in favor of readability. The difference is just copying `lambda` once let two_y = chip.scalar_mul_no_carry(ctx, &P.y, 2); let three_x = chip.scalar_mul_no_carry(ctx, &P.x, 3); let three_x_sq = chip.mul_no_carry(ctx, &three_x, &P.x); - let lambda = chip.divide_unsafe(ctx, &three_x_sq, &two_y); + + // add a, for secp256k1 a = 0, for secp256r1, a > 0 + let a_const = FC::fe_to_constant(C::a()); + let three_x_plus_a = chip.add_constant_no_carry(ctx, &three_x_sq, a_const); + + let lambda = chip.divide_unsafe(ctx, &three_x_plus_a, &two_y); // x_3 = lambda^2 - 2 x % p let lambda_sq = chip.mul_no_carry(ctx, &lambda, &lambda); @@ -229,7 +237,7 @@ where // - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits) // - `max_bits <= modulus::.bits()` // * P has order given by the scalar field modulus -pub fn scalar_multiply( +pub fn scalar_multiply( chip: &FC, ctx: &mut Context, P: &EcPoint, @@ -239,6 +247,7 @@ pub fn scalar_multiply( ) -> EcPoint where FC: FieldChip + Selectable, + C: CurveAffineExt, { assert!(!scalar.is_empty()); assert!((max_bits as u64) <= modulus::().bits()); @@ -292,7 +301,7 @@ where cached_points.push(P.clone()); for idx in 2..cache_size { if idx == 2 { - let double = ec_double(chip, ctx, P /*, b*/); + let double = ec_double::(chip, ctx, P /*, b*/); cached_points.push(double.clone()); } else { let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], P, false); @@ -311,7 +320,7 @@ where for idx in 1..num_windows { let mut mult_point = curr_point.clone(); for _ in 0..window_bits { - mult_point = ec_double(chip, ctx, &mult_point); + mult_point = ec_double::(chip, ctx, &mult_point); } let add_point = ec_select_from_bits::( chip, @@ -430,7 +439,7 @@ where let mut rand_start_vec = Vec::with_capacity(k + window_bits); rand_start_vec.push(base); for idx in 1..(k + window_bits) { - let base_mult = ec_double(chip, ctx, &rand_start_vec[idx - 1]); + let base_mult = ec_double::(chip, ctx, &rand_start_vec[idx - 1]); rand_start_vec.push(base_mult); } assert!(rand_start_vec.len() >= k + window_bits); @@ -481,7 +490,7 @@ where // compute \sum_i x_i P_i + (2^{k + 1} - 1) * A for idx in 0..num_windows { for _ in 0..window_bits { - curr_point = ec_double(chip, ctx, &curr_point); + curr_point = ec_double::(chip, ctx, &curr_point); } for (cached_points, rounded_bits) in cached_points.chunks(cache_size).zip(rounded_bits.chunks(rounded_bitlen)) @@ -687,12 +696,15 @@ impl> EccChip { ec_sub_unequal(&self.field_chip, ctx, P, Q, is_strict) } - pub fn double( + pub fn double( &self, ctx: &mut Context, P: &EcPoint, - ) -> EcPoint { - ec_double(&self.field_chip, ctx, P) + ) -> EcPoint + where + C: CurveAffine, + { + ec_double::(&self.field_chip, ctx, P) } pub fn is_equal( @@ -751,15 +763,18 @@ where ec_select(&self.field_chip, ctx, P, Q, condition) } - pub fn scalar_mult( + pub fn scalar_mult( &self, ctx: &mut Context, P: &EcPoint, scalar: &Vec>, max_bits: usize, window_bits: usize, - ) -> EcPoint { - scalar_multiply::(&self.field_chip, ctx, P, scalar, max_bits, window_bits) + ) -> EcPoint + where + C: CurveAffine, + { + scalar_multiply::(&self.field_chip, ctx, P, scalar, max_bits, window_bits) } // TODO: put a check in place that scalar is < modulus of C::Scalar diff --git a/halo2-ecc/src/ecc/pippenger.rs b/halo2-ecc/src/ecc/pippenger.rs index b713966e..2871421c 100644 --- a/halo2-ecc/src/ecc/pippenger.rs +++ b/halo2-ecc/src/ecc/pippenger.rs @@ -12,7 +12,7 @@ use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, // Output: // * new_points: length `points.len() * radix` // * new_bool_scalars: 2d array `ceil(scalar_bits / radix)` by `points.len() * radix` -pub fn decompose( +pub fn decompose( chip: &FC, ctx: &mut Context, points: &[EcPoint], @@ -23,6 +23,7 @@ pub fn decompose( where F: PrimeField, FC: FieldChip, + C: CurveAffineExt, { assert_eq!(points.len(), scalars.len()); let scalar_bits = max_scalar_bits_per_cell * scalars[0].len(); @@ -38,7 +39,7 @@ where new_points.push(g); for _ in 1..radix { // if radix > 1, this does not work if `points` contains identity point - g = ec_double(chip, ctx, new_points.last().unwrap()); + g = ec_double::(chip, ctx, new_points.last().unwrap()); new_points.push(g); } let mut bits = Vec::with_capacity(scalar_bits); @@ -88,7 +89,7 @@ where // for later addition collision-prevension, we need a different random point per round // we take 2^round * rand_base if round > 0 { - rand_point = ec_double(chip, ctx, &rand_point); + rand_point = ec_double::(chip, ctx, &rand_point); } // stores { rand_point, rand_point + points[0], rand_point + points[1], rand_point + points[0] + points[1] , ... } // since rand_point is random, we can always use add_unequal (with strict constraint checking that the points are indeed unequal and not negative of each other) @@ -129,7 +130,7 @@ where } // we have acc[j] = G'[j] + (2^num_rounds - 1) * rand_base - rand_point = ec_double(chip, ctx, &rand_point); + rand_point = ec_double::(chip, ctx, &rand_point); rand_point = ec_sub_unequal(chip, ctx, &rand_point, &rand_base, false); (acc, rand_point) @@ -149,7 +150,7 @@ where C: CurveAffineExt, { let (points, bool_scalars) = - decompose::(chip, ctx, points, scalars, max_scalar_bits_per_cell, radix); + decompose::(chip, ctx, points, scalars, max_scalar_bits_per_cell, radix); /* let t = bool_scalars.len(); @@ -179,8 +180,8 @@ where let mut rand_sum = rand_point.clone(); for g in agg.iter().rev() { for _ in 0..radix { - sum = ec_double(chip, ctx, &sum); - rand_sum = ec_double(chip, ctx, &rand_sum); + sum = ec_double::(chip, ctx, &sum); + rand_sum = ec_double::(chip, ctx, &rand_sum); } sum = ec_add_unequal(chip, ctx, &sum, g, true); chip.enforce_less_than(ctx, sum.x()); @@ -192,7 +193,7 @@ where } if radix == 1 { - rand_sum = ec_double(chip, ctx, &rand_sum); + rand_sum = ec_double::(chip, ctx, &rand_sum); // assume 2^t != +-1 mod modulus::() rand_sum = ec_sub_unequal(chip, ctx, &rand_sum, &rand_point, false); } diff --git a/halo2-ecc/src/ecc/tests.rs b/halo2-ecc/src/ecc/tests.rs index 8fe0c382..1b5589f5 100644 --- a/halo2-ecc/src/ecc/tests.rs +++ b/halo2-ecc/src/ecc/tests.rs @@ -110,7 +110,7 @@ impl Circuit for MyCircuit { // test double { - let doub = chip.double(ctx, &P_assigned); + let doub = chip.double::(ctx, &P_assigned); assert_eq!( value_to_option(doub.x.truncation.to_bigint(config.limb_bits)), value_to_option(doub.x.value.clone()) diff --git a/halo2-ecc/src/fields/tests.rs b/halo2-ecc/src/fields/tests.rs index eb8de39b..26c1f778 100644 --- a/halo2-ecc/src/fields/tests.rs +++ b/halo2-ecc/src/fields/tests.rs @@ -10,6 +10,7 @@ mod fp { halo2curves::bn256::{Fq, Fr}, plonk::*, }; + use ff::Field; use halo2_base::{ utils::{fe_to_biguint, modulus}, SKIP_FIRST_PASS, @@ -141,6 +142,7 @@ mod fp12 { halo2curves::bn256::{Fq, Fq12, Fr}, plonk::*, }; + use ff::Field; use halo2_base::utils::modulus; use halo2_base::SKIP_FIRST_PASS; use std::marker::PhantomData; diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa.rs b/halo2-ecc/src/secp256k1/tests/ecdsa.rs index 6b17d91b..ad9a869f 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa.rs @@ -1,6 +1,7 @@ #![allow(non_snake_case)] use crate::fields::PrimeField; use ark_std::{end_timer, start_timer}; +use ff::Field; use halo2_base::SKIP_FIRST_PASS; use serde::{Deserialize, Serialize}; use std::fs::File;