diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8df344f3..1ea56f6e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,11 +27,12 @@ jobs: cd halo2-ecc cargo test -- --test-threads=1 test_fp cargo test -- test_ecc - cargo test -- test_secp256k1_ecdsa + cargo test -- test_secp cargo test -- test_ecdsa cargo test -- test_ec_add - cargo test -- test_fixed_base_msm + cargo test -- test_fixed cargo test -- test_msm + cargo test -- test_fb cargo test -- test_pairing cd .. - name: Run halo2-ecc tests real prover diff --git a/halo2-base/src/utils.rs b/halo2-base/src/utils.rs index f98a28da..69c1a1f9 100644 --- a/halo2-base/src/utils.rs +++ b/halo2-base/src/utils.rs @@ -81,7 +81,7 @@ pub(crate) fn decompose_u64_digits_to_limbs( core::cmp::Ordering::Less => { let mut limb = u64_digit; u64_digit = e.next().unwrap_or(0); - limb |= (u64_digit & ((1 << (bit_len - rem)) - 1)) << rem; + limb |= (u64_digit & ((1u64 << (bit_len - rem)) - 1u64)) << rem; u64_digit >>= bit_len - rem; rem += 64 - bit_len; limb @@ -218,7 +218,7 @@ pub fn decompose_biguint( let mut rem = bit_len - 64; let mut u64_digit = e.next().unwrap_or(0); // Extract second limb (bit length 64) from e - limb0 |= ((u64_digit & ((1 << rem) - 1u64)) as u128) << 64u32; + limb0 |= ((u64_digit & ((1u64 << rem) - 1u64)) as u128) << 64u32; u64_digit >>= rem; rem = 64 - rem; @@ -234,7 +234,7 @@ pub fn decompose_biguint( bits += 64; } rem = bit_len - bits; - limb |= ((u64_digit & ((1 << rem) - 1)) as u128) << bits; + limb |= ((u64_digit & ((1u64 << rem) - 1u64)) as u128) << bits; u64_digit >>= rem; rem = 64 - rem; F::from_u128(limb) diff --git a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs index a8f039c2..0283f672 100644 --- a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs +++ b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs @@ -23,7 +23,7 @@ use itertools::Itertools; use rand_core::OsRng; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct MSMCircuitParams { +struct FixedMSMCircuitParams { strategy: FpStrategy, degree: u32, num_advice: usize, @@ -39,7 +39,7 @@ struct MSMCircuitParams { fn fixed_base_msm_test( builder: &mut GateThreadBuilder, - params: MSMCircuitParams, + params: FixedMSMCircuitParams, bases: Vec, scalars: Vec, ) { @@ -68,7 +68,7 @@ fn fixed_base_msm_test( } fn random_fixed_base_msm_circuit( - params: MSMCircuitParams, + params: FixedMSMCircuitParams, bases: Vec, // bases are fixed in vkey so don't randomly generate stage: CircuitBuilderStage, break_points: Option, @@ -102,7 +102,7 @@ fn random_fixed_base_msm_circuit( #[test] fn test_fixed_base_msm() { let path = "configs/bn254/fixed_msm_circuit.config"; - let params: MSMCircuitParams = serde_json::from_reader( + let params: FixedMSMCircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); @@ -112,6 +112,23 @@ fn test_fixed_base_msm() { MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } +#[test] +fn test_fixed_msm_minus_1() { + let path = "configs/bn254/fixed_msm_circuit.config"; + let params: FixedMSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + let base = G1Affine::random(OsRng); + let k = params.degree as usize; + let mut builder = GateThreadBuilder::mock(); + fixed_base_msm_test(&mut builder, params, vec![base], vec![-Fr::one()]); + + builder.config(k, Some(20)); + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + #[test] fn bench_fixed_base_msm() -> Result<(), Box> { let config_path = "configs/bn254/bench_fixed_msm.config"; @@ -126,7 +143,8 @@ fn bench_fixed_base_msm() -> Result<(), Box> { let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { - let bench_params: MSMCircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); + let bench_params: FixedMSMCircuitParams = + serde_json::from_str(line.unwrap().as_str()).unwrap(); let k = bench_params.degree; println!("---------------------- degree = {k} ------------------------------",); let rng = OsRng; diff --git a/halo2-ecc/src/bn254/tests/mod.rs b/halo2-ecc/src/bn254/tests/mod.rs index b373d51e..172300a1 100644 --- a/halo2-ecc/src/bn254/tests/mod.rs +++ b/halo2-ecc/src/bn254/tests/mod.rs @@ -1,20 +1,23 @@ #![allow(non_snake_case)] use super::pairing::PairingChip; use super::*; -use crate::halo2_proofs::{ - dev::MockProver, - halo2curves::bn256::{pairing, Bn256, Fr, G1Affine}, - plonk::*, - poly::commitment::ParamsProver, - poly::kzg::{ - commitment::KZGCommitmentScheme, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, +use crate::{ecc::EccChip, fields::PrimeField}; +use crate::{ + fields::FpStrategy, + halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{pairing, Bn256, Fr, G1Affine}, + plonk::*, + poly::commitment::ParamsProver, + poly::kzg::{ + commitment::KZGCommitmentScheme, + multiopen::{ProverSHPLONK, VerifierSHPLONK}, + strategy::SingleStrategy, + }, + transcript::{Blake2bRead, Blake2bWrite, Challenge255}, + transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, }, - transcript::{Blake2bRead, Blake2bWrite, Challenge255}, - transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, }; -use crate::{ecc::EccChip, fields::PrimeField}; use ark_std::{end_timer, start_timer}; use group::Curve; use halo2_base::utils::fe_to_biguint; @@ -24,4 +27,20 @@ use std::io::Write; pub mod ec_add; pub mod fixed_base_msm; pub mod msm; +pub mod msm_sum_infinity; +pub mod msm_sum_infinity_fixed_base; pub mod pairing; + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +struct MSMCircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + batch_size: usize, + window_bits: usize, +} diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs new file mode 100644 index 00000000..600a4931 --- /dev/null +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs @@ -0,0 +1,183 @@ +use ff::PrimeField; +use halo2_base::gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + }, + RangeChip, +}; +use rand_core::OsRng; +use std::fs::File; + +use super::*; + +fn msm_test( + builder: &mut GateThreadBuilder, + params: MSMCircuitParams, + bases: Vec, + scalars: Vec, + window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let ctx = builder.main(0); + let scalars_assigned = + scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); + let bases_assigned = bases + .iter() + .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) + .collect::>(); + + let msm = ecc_chip.variable_base_msm_in::( + builder, + &bases_assigned, + scalars_assigned, + Fr::NUM_BITS as usize, + window_bits, + 0, + ); + + let msm_answer = bases + .iter() + .zip(scalars.iter()) + .map(|(base, scalar)| base * scalar) + .reduce(|a, b| a + b) + .unwrap() + .to_affine(); + + let msm_x = msm.x.value(); + let msm_y = msm.y.value(); + assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); + assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); +} + +fn custom_msm_circuit( + params: MSMCircuitParams, + stage: CircuitBuilderStage, + break_points: Option, + bases: Vec, + scalars: Vec, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + msm_test(&mut builder, params, bases, scalars, params.window_bits); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit +} + +#[test] +fn test_msm1() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, random_point]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm2() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm3() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = vec![ + random_point, + random_point, + random_point, + (random_point + random_point + random_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm4() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let generator_point = G1Affine::generator(); + let bases = vec![ + generator_point, + generator_point, + generator_point, + (generator_point + generator_point + generator_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm5() { + // Very similar example that does not add to infinity. It works fine. + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = + vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs new file mode 100644 index 00000000..6cf96c7f --- /dev/null +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs @@ -0,0 +1,183 @@ +use ff::PrimeField; +use halo2_base::gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + }, + RangeChip, +}; +use rand_core::OsRng; +use std::fs::File; + +use super::*; + +fn msm_test( + builder: &mut GateThreadBuilder, + params: MSMCircuitParams, + bases: Vec, + scalars: Vec, + window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let ctx = builder.main(0); + let scalars_assigned = + scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); + let bases_assigned = bases; + //.iter() + //.map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) + //.collect::>(); + + let msm = ecc_chip.fixed_base_msm_in::( + builder, + &bases_assigned, + scalars_assigned, + Fr::NUM_BITS as usize, + window_bits, + 0, + ); + + let msm_answer = bases_assigned + .iter() + .zip(scalars.iter()) + .map(|(base, scalar)| base * scalar) + .reduce(|a, b| a + b) + .unwrap() + .to_affine(); + + let msm_x = msm.x.value(); + let msm_y = msm.y.value(); + assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); + assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); +} + +fn custom_msm_circuit( + params: MSMCircuitParams, + stage: CircuitBuilderStage, + break_points: Option, + bases: Vec, + scalars: Vec, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + msm_test(&mut builder, params, bases, scalars, params.window_bits); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit +} + +#[test] +fn test_fb_msm1() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, random_point]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm2() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm3() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = vec![ + random_point, + random_point, + random_point, + (random_point + random_point + random_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm4() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let generator_point = G1Affine::generator(); + let bases = vec![ + generator_point, + generator_point, + generator_point, + (generator_point + generator_point + generator_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm5() { + // Very similar example that does not add to infinity. It works fine. + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = + vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} diff --git a/halo2-ecc/src/ecc/ecdsa.rs b/halo2-ecc/src/ecc/ecdsa.rs index d7406a17..ca0b111b 100644 --- a/halo2-ecc/src/ecc/ecdsa.rs +++ b/halo2-ecc/src/ecc/ecdsa.rs @@ -3,8 +3,7 @@ use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, use crate::bigint::{big_is_equal, big_less_than, FixedOverflowInteger, ProperCrtUint}; use crate::fields::{fp::FpChip, FieldChip, PrimeField}; -use super::{fixed_base, EccChip}; -use super::{scalar_multiply, EcPoint}; +use super::{fixed_base, scalar_multiply, EcPoint, EccChip}; // CF is the coordinate field of GA // SF is the scalar field of GA // p = coordinate field modulus @@ -12,6 +11,7 @@ use super::{scalar_multiply, EcPoint}; // Only valid when p is very close to n in size (e.g. for Secp256k1) // Assumes `r, s` are proper CRT integers /// **WARNING**: Only use this function if `1 / (p - n)` is very small (e.g., < 2-100) +/// `pubkey` should not be the identity point pub fn ecdsa_verify_no_pubkey_check( chip: &EccChip>, ctx: &mut Context, @@ -49,16 +49,14 @@ where u1.limbs().to_vec(), base_chip.limb_bits, fixed_window_bits, - true, // we can call it with scalar_is_safe = true because of the u1_small check below ); - let u2_mul = scalar_multiply( + let u2_mul = scalar_multiply::<_, _, GA>( base_chip, ctx, pubkey, u2.limbs().to_vec(), base_chip.limb_bits, var_window_bits, - true, // we can call it with scalar_is_safe = true because of the u2_small check below ); // check u1 * G != -(u2 * pubkey) but allow u1 * G == u2 * pubkey diff --git a/halo2-ecc/src/ecc/fixed_base.rs b/halo2-ecc/src/ecc/fixed_base.rs index dc67b8d6..5dfba754 100644 --- a/halo2-ecc/src/ecc/fixed_base.rs +++ b/halo2-ecc/src/ecc/fixed_base.rs @@ -1,6 +1,6 @@ #![allow(non_snake_case)] use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip}; -use crate::ecc::ec_sub_strict; +use crate::ecc::{ec_sub_strict, load_random_point}; use crate::fields::{FieldChip, PrimeField, Selectable}; use group::Curve; use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder}; @@ -17,8 +17,6 @@ use std::cmp::min; /// # Assumptions /// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits) /// - `scalar > 0` -/// - If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`) -/// - Even if `scalar_is_safe == false`, some constraints may still fail if `scalar` is not in range [1, order of `P`) /// - `max_bits <= modulus::.bits()` pub fn scalar_multiply( chip: &FC, @@ -27,7 +25,6 @@ pub fn scalar_multiply( scalar: Vec>, max_bits: usize, window_bits: usize, - scalar_is_safe: bool, ) -> EcPoint where F: PrimeField, @@ -87,29 +84,19 @@ where let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev(); let bit_window_rev = bits.chunks(window_bits).rev(); - let mut curr_point = None; - // `is_started` is just a way to deal with if `curr_point` is actually identity - let mut is_started = ctx.load_zero(); + let any_point = load_random_point::(chip, ctx); + let mut curr_point = any_point.clone(); for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) { let bit_sum = chip.gate().sum(ctx, bit_window.iter().copied()); // are we just adding a window of all 0s? if so, skip let is_zero_window = chip.gate().is_zero(ctx, bit_sum); - let add_point = ec_select_from_bits(chip, ctx, cached_point_window, bit_window); - curr_point = if let Some(curr_point) = curr_point { - let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, !scalar_is_safe); - let zero_sum = ec_select(chip, ctx, curr_point, sum, is_zero_window); - Some(ec_select(chip, ctx, zero_sum, add_point, is_started)) - } else { - Some(add_point) - }; - is_started = { - // is_started || !is_zero_window - // (a || !b) = (1-b) + a*b - let not_zero_window = chip.gate().not(ctx, is_zero_window); - chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window) + curr_point = { + let add_point = ec_select_from_bits(chip, ctx, cached_point_window, bit_window); + let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, true); + ec_select(chip, ctx, curr_point, sum, is_zero_window) }; } - curr_point.unwrap() + ec_sub_strict(chip, ctx, curr_point, any_point) } // basically just adding up individual fixed_base::scalar_multiply except that we do all batched normalization of cached points at once to further save inversion time during witness generation @@ -120,7 +107,7 @@ where /// * `scalars[i].len() = scalars[j].len()` for all `i,j` /// * `points` are all on the curve /// * `points[i]` is not point at infinity (0, 0); these should be filtered out beforehand -/// * The integer value of `scalars[i]` is less than the order of `points[i]` (some constraints may fail otherwise) +/// * The integer value of `scalars[i]` is less than the order of `points[i]` /// * Output may be point at infinity, in which case (0, 0) is returned pub fn msm_par( chip: &EccChip, @@ -153,6 +140,7 @@ where .flat_map(|point| -> Vec<_> { let base_pt = point.to_curve(); // cached_points[idx][i * 2^w + j] holds `[j * 2^(i * w)] * points[idx]` for j in {0, ..., 2^w - 1} + // EXCEPT cached_points[idx][0] = points[idx] let mut increment = base_pt; (0..num_windows) .flat_map(|i| { @@ -178,8 +166,9 @@ where C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine); let field_chip = chip.field_chip(); + let ctx = builder.main(phase); + let any_point = chip.load_random_point::(ctx); - let zero = builder.main(phase).load_zero(); let scalar_mults = parallelize_in( phase, builder, @@ -202,41 +191,29 @@ where }) .collect::>(); let bit_window_rev = bits.chunks(window_bits).rev(); - let mut curr_point = None; - // `is_started` is just a way to deal with if `curr_point` is actually identity - let mut is_started = zero; + let mut curr_point = any_point.clone(); for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) { let is_zero_window = { let sum = field_chip.gate().sum(ctx, bit_window.iter().copied()); field_chip.gate().is_zero(ctx, sum) }; - let add_point = - ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window); - curr_point = if let Some(curr_point) = curr_point { - // We don't need strict mode because we assume scalars[i] is less than the order of points[i] - let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, false); - let zero_sum = ec_select(field_chip, ctx, curr_point, sum, is_zero_window); - Some(ec_select(field_chip, ctx, zero_sum, add_point, is_started)) - } else { - Some(add_point) - }; - is_started = { - // is_started || !is_zero_window - // (a || !b) = (1-b) + a*b - let not_zero_window = field_chip.gate().not(ctx, is_zero_window); - field_chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window) + curr_point = { + let add_point = + ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window); + let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, true); + ec_select(field_chip, ctx, curr_point, sum, is_zero_window) }; } - (curr_point.unwrap(), is_started) + curr_point }, ); let ctx = builder.main(phase); // sum `scalar_mults` but take into account possiblity of identity points - let any_point = chip.load_random_point::(ctx); - let mut acc = any_point.clone(); - for (point, is_not_identity) in scalar_mults { + let any_point2 = chip.load_random_point::(ctx); + let mut acc = any_point2.clone(); + for point in scalar_mults { let new_acc = chip.add_unequal(ctx, &acc, point, true); - acc = chip.select(ctx, new_acc, acc, is_not_identity); + acc = chip.sub_unequal(ctx, new_acc, &any_point, true); } - ec_sub_strict(field_chip, ctx, acc, any_point) + ec_sub_strict(field_chip, ctx, acc, any_point2) } diff --git a/halo2-ecc/src/ecc/mod.rs b/halo2-ecc/src/ecc/mod.rs index a4dedd5f..d63b4c4a 100644 --- a/halo2-ecc/src/ecc/mod.rs +++ b/halo2-ecc/src/ecc/mod.rs @@ -3,9 +3,10 @@ use crate::fields::{fp::FpChip, FieldChip, PrimeField, Selectable}; use crate::halo2_proofs::arithmetic::CurveAffine; use group::{Curve, Group}; use halo2_base::gates::builder::GateThreadBuilder; +use halo2_base::utils::modulus; use halo2_base::{ gates::{GateInstructions, RangeInstructions}, - utils::{modulus, CurveAffineExt}, + utils::CurveAffineExt, AssignedValue, Context, }; use itertools::Itertools; @@ -224,14 +225,9 @@ pub fn ec_sub_unequal>( let (P, Q) = check_points_are_unequal(chip, ctx, P, Q, is_strict); let dx = chip.sub_no_carry(ctx, &Q.x, &P.x); - let dy = chip.add_no_carry(ctx, Q.y, &P.y); + let sy = chip.add_no_carry(ctx, Q.y, &P.y); - let lambda = chip.neg_divide_unsafe(ctx, &dy, &dx); - - // (x_2 - x_1) * lambda + y_2 + y_1 = 0 (mod p) - let lambda_dx = chip.mul_no_carry(ctx, &lambda, dx); - let lambda_dx_plus_dy = chip.add_no_carry(ctx, lambda_dx, dy); - chip.check_carry_mod_to_zero(ctx, lambda_dx_plus_dy); + let lambda = chip.neg_divide_unsafe(ctx, sy, dx); // x_3 = lambda^2 - x_1 - x_2 (mod p) let lambda_sq = chip.mul_no_carry(ctx, &lambda, &lambda); @@ -259,7 +255,7 @@ pub fn ec_sub_strict>( where FC: Selectable, { - let P = P.into(); + let mut P = P.into(); let Q = Q.into(); // Compute curr_point - start_point, allowing for output to be identity point let x_is_eq = chip.is_equal(ctx, P.x(), Q.x()); @@ -268,6 +264,17 @@ where // we ONLY allow x_is_eq = true if y_is_eq is also true; this constrains P != -Q ctx.constrain_equal(&x_is_eq, &is_identity); + // P.x = Q.x and P.y = Q.y + // in ec_sub_unequal it will try to do -(P.y + Q.y) / (P.x - Q.x) = -2P.y / 0 + // this will cause divide_unsafe to panic when P.y != 0 + // to avoid this, we load a random pair of points and replace P with it *only if* `is_identity == true` + // we don't even check (rand_x, rand_y) is on the curve, since we don't care about the output + let mut rng = ChaCha20Rng::from_entropy(); + let [rand_x, rand_y] = [(); 2].map(|_| FC::FieldType::random(&mut rng)); + let [rand_x, rand_y] = [rand_x, rand_y].map(|x| chip.load_private(ctx, x)); + let rand_pt = EcPoint::new(rand_x, rand_y); + P = ec_select(chip, ctx, rand_pt, P, is_identity); + let out = ec_sub_unequal(chip, ctx, P, Q, false); let zero = chip.load_constant(ctx, FC::FieldType::zero()); ec_select(chip, ctx, EcPoint::new(zero.clone(), zero), out, is_identity) @@ -469,26 +476,26 @@ where /// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` /// /// # Assumptions -/// - `P` is not the point at infinity -/// - `scalar > 0` -/// - If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`) -/// - Even if `scalar_is_safe == false`, some constraints may still fail if `scalar` is not in range [1, order of `P`) +/// - `window_bits != 0` +/// - The order of `P` is at least `2^{window_bits}` (in particular, `P` is not the point at infinity) +/// - The curve has no points of order 2. /// - `scalar_i < 2^{max_bits} for all i` /// - `max_bits <= modulus::.bits()`, and equality only allowed when the order of `P` equals the modulus of `F` -pub fn scalar_multiply( +pub fn scalar_multiply( chip: &FC, ctx: &mut Context, P: EcPoint, scalar: Vec>, max_bits: usize, window_bits: usize, - scalar_is_safe: bool, ) -> EcPoint where FC: FieldChip + Selectable, + C: CurveAffineExt, { assert!(!scalar.is_empty()); assert!((max_bits as u64) <= modulus::().bits()); + assert!(window_bits != 0); let total_bits = max_bits * scalar.len(); let num_windows = (total_bits + window_bits - 1) / window_bits; @@ -506,7 +513,7 @@ where // is_started[idx] holds whether there is a 1 in bits with index at least (rounded_bitlen - idx) let mut is_started = Vec::with_capacity(rounded_bitlen); is_started.resize(rounded_bitlen - total_bits + 1, zero_cell); - for idx in 1..total_bits { + for idx in 1..=total_bits { let or = chip.gate().or(ctx, *is_started.last().unwrap(), rounded_bits[total_bits - idx]); is_started.push(or); } @@ -523,22 +530,23 @@ where is_zero_window.push(is_zero); } - // cached_points[idx] stores idx * P, with cached_points[0] = P + let any_point = load_random_point::(chip, ctx); + // cached_points[idx] stores idx * P, with cached_points[0] = any_point let cache_size = 1usize << window_bits; let mut cached_points = Vec::with_capacity(cache_size); - cached_points.push(P.clone()); + cached_points.push(any_point); cached_points.push(P.clone()); for idx in 2..cache_size { if idx == 2 { let double = ec_double(chip, ctx, &P); cached_points.push(double); } else { - let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, !scalar_is_safe); + let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, false); cached_points.push(new_point); } } - // if all the starting window bits are 0, get start_point = P + // if all the starting window bits are 0, get start_point = any_point let mut curr_point = ec_select_from_bits( chip, ctx, @@ -558,13 +566,17 @@ where &rounded_bits [rounded_bitlen - window_bits * (idx + 1)..rounded_bitlen - window_bits * idx], ); - let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, !scalar_is_safe); + // if is_zero_window[idx] = true, add_point = any_point. We only need any_point to avoid divide by zero in add_unequal + // if is_zero_window = true and is_started = false, then mult_point = 2^window_bits * any_point. Since window_bits != 0, we have mult_point != +- any_point + let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, true); let is_started_point = ec_select(chip, ctx, mult_point, mult_and_add, is_zero_window[idx]); curr_point = ec_select(chip, ctx, is_started_point, add_point, is_started[window_bits * idx]); } - curr_point + // if at the end, return identity point (0,0) if still not started + let zero = chip.load_constant(ctx, FC::FieldType::zero()); + ec_select(chip, ctx, curr_point, EcPoint::new(zero.clone(), zero), *is_started.last().unwrap()) } /// Checks that `P` is indeed a point on the elliptic curve `C`. @@ -1007,24 +1019,18 @@ where } /// See [`scalar_multiply`] for more details. - pub fn scalar_mult( + pub fn scalar_mult( &self, ctx: &mut Context, P: EcPoint, scalar: Vec>, max_bits: usize, window_bits: usize, - scalar_is_safe: bool, - ) -> EcPoint { - scalar_multiply::( - self.field_chip, - ctx, - P, - scalar, - max_bits, - window_bits, - scalar_is_safe, - ) + ) -> EcPoint + where + C: CurveAffineExt, + { + scalar_multiply::(self.field_chip, ctx, P, scalar, max_bits, window_bits) } // default for most purposes @@ -1038,14 +1044,13 @@ where ) -> EcPoint where C: CurveAffineExt, - C::Base: ff::PrimeField, FC: Selectable, { // window_bits = 4 is optimal from empirical observations self.variable_base_msm_in::(thread_pool, P, scalars, max_bits, 4, 0) } - // TODO: put a check in place that scalar is < modulus of C::Scalar + // TODO: add asserts to validate input assumptions described in docs pub fn variable_base_msm_in( &self, builder: &mut GateThreadBuilder, @@ -1057,7 +1062,6 @@ where ) -> EcPoint where C: CurveAffineExt, - C::Base: ff::PrimeField, FC: Selectable, { #[cfg(feature = "display")] @@ -1104,7 +1108,6 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { scalar: Vec>, max_bits: usize, window_bits: usize, - scalar_is_safe: bool, ) -> EcPoint where C: CurveAffineExt, @@ -1117,7 +1120,6 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { scalar, max_bits, window_bits, - scalar_is_safe, ) } diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs index 27d4c1c6..45e251f3 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs @@ -1,5 +1,4 @@ #![allow(non_snake_case)] -use crate::fields::FpStrategy; use crate::halo2_proofs::{ arithmetic::CurveAffine, dev::MockProver, @@ -21,21 +20,10 @@ use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus}; use halo2_base::Context; use rand::random; use rand_core::OsRng; -use serde::{Deserialize, Serialize}; use std::fs::File; use test_case::test_case; -#[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct CircuitParams { - strategy: FpStrategy, - degree: u32, - num_advice: usize, - num_lookup_advice: usize, - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, -} +use super::CircuitParams; fn ecdsa_test( ctx: &mut Context, diff --git a/halo2-ecc/src/secp256k1/tests/mod.rs b/halo2-ecc/src/secp256k1/tests/mod.rs index cdd58dd8..803ac232 100644 --- a/halo2-ecc/src/secp256k1/tests/mod.rs +++ b/halo2-ecc/src/secp256k1/tests/mod.rs @@ -1,2 +1,162 @@ +#![allow(non_snake_case)] +use std::fs::File; + +use ff::Field; +use group::Curve; +use halo2_base::{ + gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, + }, + RangeChip, + }, + halo2_proofs::{ + dev::MockProver, + halo2curves::{ + bn256::Fr, + secp256k1::{Fq, Secp256k1Affine}, + }, + }, + utils::{biguint_to_fe, fe_to_biguint, BigPrimeField}, + Context, +}; +use num_bigint::BigUint; +use rand_core::OsRng; +use serde::{Deserialize, Serialize}; + +use crate::{ + ecc::EccChip, + fields::{FieldChip, FpStrategy}, + secp256k1::{FpChip, FqChip}, +}; + pub mod ecdsa; pub mod ecdsa_tests; + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +struct CircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, +} + +fn sm_test( + ctx: &mut Context, + params: CircuitParams, + base: Secp256k1Affine, + scalar: Fq, + window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::>::new(&fp_chip); + + let s = fq_chip.load_private(ctx, scalar); + let P = ecc_chip.assign_point(ctx, base); + + let sm = ecc_chip.scalar_mult::( + ctx, + P, + s.limbs().to_vec(), + fq_chip.limb_bits, + window_bits, + ); + + let sm_answer = (base * scalar).to_affine(); + + let sm_x = sm.x.value(); + let sm_y = sm.y.value(); + assert_eq!(sm_x, fe_to_biguint(&sm_answer.x)); + assert_eq!(sm_y, fe_to_biguint(&sm_answer.y)); +} + +fn sm_circuit( + params: CircuitParams, + stage: CircuitBuilderStage, + break_points: Option, + base: Secp256k1Affine, + scalar: Fq, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = GateThreadBuilder::new(stage == CircuitBuilderStage::Prover); + + sm_test(builder.main(0), params, base, scalar, 4); + + match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + } +} + +#[test] +fn test_secp_sm_random() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let circuit = sm_circuit( + params, + CircuitBuilderStage::Mock, + None, + Secp256k1Affine::random(OsRng), + Fq::random(OsRng), + ); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_secp_sm_minus_1() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let base = Secp256k1Affine::random(OsRng); + let mut s = -Fq::one(); + let mut n = fe_to_biguint(&s); + loop { + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + if &n % BigUint::from(2usize) == BigUint::from(0usize) { + break; + } + n /= 2usize; + s = biguint_to_fe(&n); + } +} + +#[test] +fn test_secp_sm_0_1() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let base = Secp256k1Affine::random(OsRng); + let s = Fq::zero(); + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + + let s = Fq::one(); + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +}