From da7d5b1671300d4988a1dc30e6e58869ec337fb5 Mon Sep 17 00:00:00 2001 From: yulliakot Date: Fri, 19 May 2023 23:57:48 -0500 Subject: [PATCH 01/12] More ecdsa tests --- halo2-ecc/Cargo.toml | 1 + halo2-ecc/src/ecc/tests.rs | 1 + halo2-ecc/src/secp256k1/tests/ecdsa.rs | 7 +- halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs | 240 +++++++++++++++++++ halo2-ecc/src/secp256k1/tests/mod.rs | 1 + 5 files changed, 249 insertions(+), 1 deletion(-) create mode 100644 halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs diff --git a/halo2-ecc/Cargo.toml b/halo2-ecc/Cargo.toml index d5c9d056..2b03e1cb 100644 --- a/halo2-ecc/Cargo.toml +++ b/halo2-ecc/Cargo.toml @@ -14,6 +14,7 @@ rand_chacha = "0.3.1" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" rayon = "1.6.1" +test-case = "3.1.0" # arithmetic ff = "0.12" diff --git a/halo2-ecc/src/ecc/tests.rs b/halo2-ecc/src/ecc/tests.rs index fb9d7abf..ad7687d3 100644 --- a/halo2-ecc/src/ecc/tests.rs +++ b/halo2-ecc/src/ecc/tests.rs @@ -95,3 +95,4 @@ fn plot_ecc() { halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); } + diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa.rs b/halo2-ecc/src/secp256k1/tests/ecdsa.rs index 739bffc7..4518d50e 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa.rs @@ -37,6 +37,7 @@ use std::io::BufReader; use std::io::Write; use std::{fs, io::BufRead}; + #[derive(Clone, Copy, Debug, Serialize, Deserialize)] struct CircuitParams { strategy: FpStrategy, @@ -72,6 +73,7 @@ fn ecdsa_test( &fp_chip, ctx, &pk, &r, &s, &m, 4, 4, ); assert_eq!(res.value(), &F::one()); + } fn random_ecdsa_circuit( @@ -94,11 +96,13 @@ fn random_ecdsa_circuit( let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); let x = r_point.x(); let x_bigint = fe_to_biguint(x); + let r = biguint_to_fe::(&(x_bigint % modulus::())); let s = k_inv * (msg_hash + (r * sk)); - + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); + let circuit = match stage { CircuitBuilderStage::Mock => { @@ -127,6 +131,7 @@ fn test_secp256k1_ecdsa() { MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } + #[test] fn bench_secp256k1_ecdsa() -> Result<(), Box> { let mut rng = OsRng; diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs new file mode 100644 index 00000000..27565d6a --- /dev/null +++ b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs @@ -0,0 +1,240 @@ +#![allow(non_snake_case)] +use crate::fields::FpStrategy; +use crate::halo2_proofs::{ + arithmetic::CurveAffine, + dev::MockProver, + halo2curves::bn256::Fr, + halo2curves::secp256k1::{Fp, Fq, Secp256k1Affine}, +}; +use crate::secp256k1::{FpChip, FqChip}; +use crate::{ + ecc::{ecdsa::ecdsa_verify_no_pubkey_check, EccChip}, + fields::{FieldChip, PrimeField}, +}; +use ark_std::{end_timer, start_timer}; +use halo2_base::gates::builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, +}; + +use halo2_base::gates::RangeChip; +use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus}; +use halo2_base::Context; +use rand_core::OsRng; +use serde::{Deserialize, Serialize}; +use std::fs::File; + + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +struct CircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, +} + +fn ecdsa_test( + ctx: &mut Context, + params: CircuitParams, + r: Fq, + s: Fq, + msghash: Fq, + pk: Secp256k1Affine, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); + + let [m, r, s] = + [msghash, r, s].map(|x| fq_chip.load_private(ctx, FqChip::::fe_to_witness(&x))); + + let ecc_chip = EccChip::>::new(&fp_chip); + let pk = ecc_chip.load_private(ctx, (pk.x, pk.y)); + // test ECDSA + let res = ecdsa_verify_no_pubkey_check::( + &fp_chip, ctx, &pk, &r, &s, &m, 4, 4, + ); + assert_eq!(res.value(), &F::one()); +} + + +fn random_parameters_ecdsa() -> ( + Fq, + Fq, + Fq, + Secp256k1Affine, + ) { + let sk = ::ScalarExt::random(OsRng); + let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); + let msg_hash = ::ScalarExt::random(OsRng); + + let k = ::ScalarExt::random(OsRng); + let k_inv = k.invert().unwrap(); + + let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); + let x = r_point.x(); + let x_bigint = fe_to_biguint(x); + + + let r = biguint_to_fe::(&(x_bigint % modulus::())); + let s = k_inv * (msg_hash + (r * sk)); + + return ( r, s, msg_hash, pubkey) +} + +fn custom_parameters_ecdsa( + sk: u64, + msg_hash: u64, + k: u64, +) -> ( + Fq, + Fq, + Fq, + Secp256k1Affine, + ){ + let sk = ::ScalarExt::from(sk); + let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); + let msg_hash = ::ScalarExt::from(msg_hash); + + let k = ::ScalarExt::from(k); + let k_inv = k.invert().unwrap(); + + let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); + let x = r_point.x(); + let x_bigint = fe_to_biguint(x); + + + let r = biguint_to_fe::(&(x_bigint % modulus::())); + let s = k_inv * (msg_hash + (r * sk)); + + return (r, s, msg_hash, pubkey) +} + + +fn ecdsa_circuit(r: Fq, s: Fq, msg_hash: Fq, pubkey: Secp256k1Affine, + params: CircuitParams, + stage: CircuitBuilderStage, + break_points: Option, +) -> RangeCircuitBuilder { + + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); + + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(params.degree as usize, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(params.degree as usize, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit +} + + + +use rand::random; +#[cfg(test)] + #[test] + #[should_panic(expected = "assertion failed: `(left == right)`")] + fn test_ecdsa_msg_hash_zero() + { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(random::(),0, random::()); + + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None,); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + } + + #[cfg(test)] + #[test] + #[should_panic(expected = "assertion failed: `(left == right)`")] + fn test_ecdsa_private_key_zero() + { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(0, random::(), random::()); + + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None,); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + } + + + +#[cfg(test)] +#[test] +fn test_ecdsa_random_valid_inputs() + { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let (r, s, msg_hash, pubkey) = random_parameters_ecdsa(); + + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None,); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + + + + +use test_case::test_case; +#[cfg(test)] +#[test_case(1, 1, 1; "")] +fn test_ecdsa_custom_valid_inputs(sk: u64,msg_hash: u64, k: u64,) + { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(sk, msg_hash, k); + + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None,); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + + + +#[cfg(test)] +#[test_case(1, 1, 1; "")] +fn test_ecdsa_custom_valid_inputs_negative_s(sk: u64,msg_hash: u64, k: u64,) + { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(sk, msg_hash, k); + let s = -s; + + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None,); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} diff --git a/halo2-ecc/src/secp256k1/tests/mod.rs b/halo2-ecc/src/secp256k1/tests/mod.rs index ecc8b287..1de89040 100644 --- a/halo2-ecc/src/secp256k1/tests/mod.rs +++ b/halo2-ecc/src/secp256k1/tests/mod.rs @@ -1 +1,2 @@ pub mod ecdsa; +pub mod unit_tests; From 150b88f6e69051477ff4e8422fe1e308e8d13676 Mon Sep 17 00:00:00 2001 From: yuliakot <93175658+yuliakot@users.noreply.github.com> Date: Sat, 20 May 2023 00:18:37 -0500 Subject: [PATCH 02/12] Update mod.rs --- halo2-ecc/src/secp256k1/tests/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/halo2-ecc/src/secp256k1/tests/mod.rs b/halo2-ecc/src/secp256k1/tests/mod.rs index 1de89040..cdd58dd8 100644 --- a/halo2-ecc/src/secp256k1/tests/mod.rs +++ b/halo2-ecc/src/secp256k1/tests/mod.rs @@ -1,2 +1,2 @@ pub mod ecdsa; -pub mod unit_tests; +pub mod ecdsa_tests; From e9386b9fbaba2c9c7aeea27c6ce914c233027cc5 Mon Sep 17 00:00:00 2001 From: yuliakot <93175658+yuliakot@users.noreply.github.com> Date: Sat, 20 May 2023 00:20:23 -0500 Subject: [PATCH 03/12] Update tests.rs --- halo2-ecc/src/ecc/tests.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/halo2-ecc/src/ecc/tests.rs b/halo2-ecc/src/ecc/tests.rs index ad7687d3..fb9d7abf 100644 --- a/halo2-ecc/src/ecc/tests.rs +++ b/halo2-ecc/src/ecc/tests.rs @@ -95,4 +95,3 @@ fn plot_ecc() { halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); } - From c42a2efafc9c72d3f8183c18ca9131ef0988816a Mon Sep 17 00:00:00 2001 From: yuliakot <93175658+yuliakot@users.noreply.github.com> Date: Sat, 20 May 2023 00:22:51 -0500 Subject: [PATCH 04/12] Update ecdsa.rs --- halo2-ecc/src/secp256k1/tests/ecdsa.rs | 6 ------ 1 file changed, 6 deletions(-) diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa.rs b/halo2-ecc/src/secp256k1/tests/ecdsa.rs index 4518d50e..834b1f1e 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa.rs @@ -37,7 +37,6 @@ use std::io::BufReader; use std::io::Write; use std::{fs, io::BufRead}; - #[derive(Clone, Copy, Debug, Serialize, Deserialize)] struct CircuitParams { strategy: FpStrategy, @@ -73,7 +72,6 @@ fn ecdsa_test( &fp_chip, ctx, &pk, &r, &s, &m, 4, 4, ); assert_eq!(res.value(), &F::one()); - } fn random_ecdsa_circuit( @@ -96,13 +94,10 @@ fn random_ecdsa_circuit( let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); let x = r_point.x(); let x_bigint = fe_to_biguint(x); - let r = biguint_to_fe::(&(x_bigint % modulus::())); let s = k_inv * (msg_hash + (r * sk)); - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); - let circuit = match stage { CircuitBuilderStage::Mock => { @@ -131,7 +126,6 @@ fn test_secp256k1_ecdsa() { MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } - #[test] fn bench_secp256k1_ecdsa() -> Result<(), Box> { let mut rng = OsRng; From 33fcccc832b25ded6ddca9712c681a1fcb25d065 Mon Sep 17 00:00:00 2001 From: yuliakot <93175658+yuliakot@users.noreply.github.com> Date: Sat, 20 May 2023 00:23:18 -0500 Subject: [PATCH 05/12] Update ecdsa.rs --- halo2-ecc/src/secp256k1/tests/ecdsa.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa.rs b/halo2-ecc/src/secp256k1/tests/ecdsa.rs index 834b1f1e..f6533ba5 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa.rs @@ -96,6 +96,7 @@ fn random_ecdsa_circuit( let x_bigint = fe_to_biguint(x); let r = biguint_to_fe::(&(x_bigint % modulus::())); let s = k_inv * (msg_hash + (r * sk)); + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); From 44bc7443db48336de7f06a3784004000801ff08c Mon Sep 17 00:00:00 2001 From: yuliakot <93175658+yuliakot@users.noreply.github.com> Date: Sat, 20 May 2023 00:24:18 -0500 Subject: [PATCH 06/12] Update ecdsa.rs --- halo2-ecc/src/secp256k1/tests/ecdsa.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa.rs b/halo2-ecc/src/secp256k1/tests/ecdsa.rs index f6533ba5..739bffc7 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa.rs @@ -96,7 +96,7 @@ fn random_ecdsa_circuit( let x_bigint = fe_to_biguint(x); let r = biguint_to_fe::(&(x_bigint % modulus::())); let s = k_inv * (msg_hash + (r * sk)); - + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); From dc5f61f71f90c341cad283b8cf2549d008d1cec7 Mon Sep 17 00:00:00 2001 From: yuliakot <93175658+yuliakot@users.noreply.github.com> Date: Mon, 29 May 2023 16:06:09 -0500 Subject: [PATCH 07/12] msm tests --- halo2-ecc/src/bn254/tests/msm_sum_infinity.rs | 197 ++++++++++++++++++ 1 file changed, 197 insertions(+) create mode 100644 halo2-ecc/src/bn254/tests/msm_sum_infinity.rs diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs new file mode 100644 index 00000000..257bea5c --- /dev/null +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs @@ -0,0 +1,197 @@ +use crate::fields::FpStrategy; +use ff::{Field, PrimeField}; +use halo2_base::{ + gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, + }, + RangeChip, + }, + utils::fs::gen_srs, +}; +use rand_core::OsRng; +use std::{ + fs::{self, File}, + io::{BufRead, BufReader}, +}; + +use super::*; + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +struct MSMCircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + batch_size: usize, + window_bits: usize, +} + +fn msm_test( + builder: &mut GateThreadBuilder, + params: MSMCircuitParams, + bases: Vec, + scalars: Vec, + window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let ctx = builder.main(0); + let scalars_assigned = + scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); + let bases_assigned = bases + .iter() + .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) + .collect::>(); + + let msm = ecc_chip.variable_base_msm_in::( + builder, + &bases_assigned, + scalars_assigned, + Fr::NUM_BITS as usize, + window_bits, + 0, + ); + + let msm_answer = bases + .iter() + .zip(scalars.iter()) + .map(|(base, scalar)| base * scalar) + .reduce(|a, b| a + b) + .unwrap() + .to_affine(); + + let msm_x = msm.x.value(); + let msm_y = msm.y.value(); + assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); + assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); +} + +fn custom_msm_circuit( + params: MSMCircuitParams, + stage: CircuitBuilderStage, + break_points: Option, + bases: Vec, + scalars: Vec, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + msm_test(&mut builder, params, bases, scalars, params.window_bits); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit +} + +#[test] +fn test_msm1() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, random_point]; + let scalars = vec![Fr::one(),Fr::one(),-Fr::one()-Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + + +#[test] +fn test_msm2() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![Fr::one(),Fr::one(),-Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + + +#[test] +fn test_msm3() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, (random_point + random_point + random_point).to_affine()]; + let scalars = vec![Fr::one(),Fr::one(),Fr::one(),-Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + + +#[test] +fn test_msm4() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let generator_point = G1Affine::generator(); + let bases = vec![generator_point, generator_point, generator_point, (generator_point + generator_point + generator_point).to_affine()]; + let scalars = vec![Fr::one(),Fr::one(),Fr::one(),-Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm5() { + // Very similar example that does not add to infinity. It works fine. + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![-Fr::one(),-Fr::one(),Fr::one(),Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} \ No newline at end of file From 9373b7cb7485b21c8699c1d099f21d6575a7e985 Mon Sep 17 00:00:00 2001 From: yuliakot <93175658+yuliakot@users.noreply.github.com> Date: Mon, 29 May 2023 16:07:22 -0500 Subject: [PATCH 08/12] Update mod.rs --- halo2-ecc/src/bn254/tests/mod.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/halo2-ecc/src/bn254/tests/mod.rs b/halo2-ecc/src/bn254/tests/mod.rs index b373d51e..c4ffb393 100644 --- a/halo2-ecc/src/bn254/tests/mod.rs +++ b/halo2-ecc/src/bn254/tests/mod.rs @@ -25,3 +25,4 @@ pub mod ec_add; pub mod fixed_base_msm; pub mod msm; pub mod pairing; +pub mod msm_sum_infinity; From d03805551c7c50186dff8aa62cc4f8b1986b7ec9 Mon Sep 17 00:00:00 2001 From: yuliakot <93175658+yuliakot@users.noreply.github.com> Date: Mon, 29 May 2023 16:10:05 -0500 Subject: [PATCH 09/12] Update msm_sum_infinity.rs --- halo2-ecc/src/bn254/tests/msm_sum_infinity.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs index 257bea5c..eb82c82c 100644 --- a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs @@ -153,7 +153,7 @@ fn test_msm3() { params.batch_size = 4; let random_point = G1Affine::random(OsRng); - let bases = vec![random_point, random_point, (random_point + random_point + random_point).to_affine()]; + let bases = vec![random_point, random_point, random_point, (random_point + random_point + random_point).to_affine()]; let scalars = vec![Fr::one(),Fr::one(),Fr::one(),-Fr::one()]; let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); @@ -194,4 +194,4 @@ fn test_msm5() { let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); -} \ No newline at end of file +} From a89f672d1a4f152d7e33caf313bd709da35f0eb0 Mon Sep 17 00:00:00 2001 From: Jonathan Wang Date: Mon, 29 May 2023 15:53:37 -0700 Subject: [PATCH 10/12] fix: ec_sub_strict was panicing when output is identity * affects the MSM functions: right now if the answer is identity, there will be a panic due to divide by 0 instead of just returning 0 * there could be a more optimal solution, but due to the traits for EccChip, we just generate a random point solely to avoid divide by 0 in the case of identity point --- .github/workflows/ci.yml | 46 ++ halo2-base/Cargo.toml | 3 + halo2-base/README.md | 590 +++++++++++++++ .../gates/tests/prop_test.txt | 11 + halo2-base/src/gates/builder.rs | 3 + halo2-base/src/gates/builder/parallelize.rs | 38 + halo2-base/src/gates/flex_gate.rs | 22 +- halo2-base/src/gates/range.rs | 40 +- halo2-base/src/gates/tests/README.md | 9 + halo2-base/src/gates/tests/flex_gate_tests.rs | 266 +++++++ .../src/gates/tests/idx_to_indicator.rs | 5 + halo2-base/src/gates/tests/mod.rs | 11 + halo2-base/src/gates/tests/neg_prop_tests.rs | 398 ++++++++++ halo2-base/src/gates/tests/pos_prop_tests.rs | 326 +++++++++ .../src/gates/tests/range_gate_tests.rs | 155 ++++ .../src/gates/tests/test_ground_truths.rs | 190 +++++ halo2-base/src/lib.rs | 1 + halo2-base/src/utils.rs | 155 ++-- halo2-ecc/benches/fp_mul.rs | 2 +- halo2-ecc/benches/msm.rs | 6 +- halo2-ecc/src/bigint/add_no_carry.rs | 30 +- halo2-ecc/src/bigint/big_is_equal.rs | 48 +- halo2-ecc/src/bigint/big_is_zero.rs | 43 +- halo2-ecc/src/bigint/big_less_than.rs | 8 +- halo2-ecc/src/bigint/carry_mod.rs | 46 +- .../src/bigint/check_carry_mod_to_zero.rs | 18 +- halo2-ecc/src/bigint/mod.rs | 154 +++- halo2-ecc/src/bigint/mul_no_carry.rs | 21 +- halo2-ecc/src/bigint/negative.rs | 2 +- .../src/bigint/scalar_mul_and_add_no_carry.rs | 30 +- halo2-ecc/src/bigint/scalar_mul_no_carry.rs | 16 +- halo2-ecc/src/bigint/select.rs | 25 +- halo2-ecc/src/bigint/select_by_indicator.rs | 24 +- halo2-ecc/src/bigint/sub.rs | 39 +- halo2-ecc/src/bigint/sub_no_carry.rs | 26 +- halo2-ecc/src/bn254/final_exp.rs | 174 ++--- halo2-ecc/src/bn254/mod.rs | 11 +- halo2-ecc/src/bn254/pairing.rs | 170 +++-- halo2-ecc/src/bn254/tests/ec_add.rs | 9 +- halo2-ecc/src/bn254/tests/fixed_base_msm.rs | 28 +- halo2-ecc/src/bn254/tests/msm.rs | 14 +- halo2-ecc/src/bn254/tests/msm_sum_infinity.rs | 59 +- halo2-ecc/src/bn254/tests/pairing.rs | 9 +- halo2-ecc/src/ecc/ecdsa.rs | 88 +-- halo2-ecc/src/ecc/fixed_base.rs | 273 ++----- halo2-ecc/src/ecc/mod.rs | 678 ++++++++++++------ halo2-ecc/src/ecc/pippenger.rs | 214 +++--- halo2-ecc/src/ecc/tests.rs | 24 +- halo2-ecc/src/fields/fp.rs | 307 +++++--- halo2-ecc/src/fields/fp12.rs | 426 +++-------- halo2-ecc/src/fields/fp2.rs | 372 ++-------- halo2-ecc/src/fields/mod.rs | 274 ++++--- halo2-ecc/src/fields/tests/fp/assert_eq.rs | 11 +- halo2-ecc/src/fields/tests/fp/mod.rs | 13 +- halo2-ecc/src/fields/tests/fp12/mod.rs | 8 +- halo2-ecc/src/fields/vector.rs | 495 +++++++++++++ halo2-ecc/src/secp256k1/tests/ecdsa.rs | 7 +- halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs | 137 ++-- .../zkevm-keccak/src/keccak_packed_multi.rs | 7 +- hashes/zkevm-keccak/src/util.rs | 8 +- .../src/util/constraint_builder.rs | 2 +- hashes/zkevm-keccak/src/util/eth_types.rs | 4 +- 62 files changed, 4556 insertions(+), 2073 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 halo2-base/README.md create mode 100644 halo2-base/proptest-regressions/gates/tests/prop_test.txt create mode 100644 halo2-base/src/gates/builder/parallelize.rs create mode 100644 halo2-base/src/gates/tests/README.md create mode 100644 halo2-base/src/gates/tests/flex_gate_tests.rs create mode 100644 halo2-base/src/gates/tests/neg_prop_tests.rs create mode 100644 halo2-base/src/gates/tests/pos_prop_tests.rs create mode 100644 halo2-base/src/gates/tests/range_gate_tests.rs create mode 100644 halo2-base/src/gates/tests/test_ground_truths.rs create mode 100644 halo2-ecc/src/fields/vector.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..4fedf24b --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,46 @@ +name: Tests + +on: + push: + branches: ["main", "release-0.3.0"] + pull_request: + branches: ["main"] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Build + run: cargo build --verbose + - name: Run halo2-base tests + run: | + cd halo2-base + cargo test -- --test-threads=1 + cd .. + - name: Run halo2-ecc tests MockProver + run: | + cd halo2-ecc + cargo test -- --test-threads=1 test_fp + cargo test -- test_ecc + cargo test -- test_secp256k1_ecdsa + cargo test -- test_ecdsa + cargo test -- test_ec_add + cargo test -- test_fixed_base_msm + cargo test -- test_msm + cargo test -- test_pairing + cd .. + - name: Run halo2-ecc tests real prover + run: | + cd halo2-ecc + cargo test --release -- test_fp_assert_eq + cargo test --release -- --nocapture bench_secp256k1_ecdsa + cargo test --release -- --nocapture bench_ec_add + cargo test --release -- --nocapture bench_fixed_base_msm + cargo test --release -- --nocapture bench_msm + cargo test --release -- --nocapture bench_pairing + cd .. diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index bc560417..33799495 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -11,6 +11,7 @@ num-traits = "0.2" rand_chacha = "0.3" rustc-hash = "1.1" ff = "0.12" +rayon = "1.6.1" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" log = "0.4" @@ -34,6 +35,8 @@ pprof = { version = "0.11", features = ["criterion", "flamegraph"] } criterion = "0.4" criterion-macro = "0.4" rayon = "1.6.1" +test-case = "3.1.0" +proptest = "1.1.0" # memory allocation [target.'cfg(not(target_env = "msvc"))'.dependencies] diff --git a/halo2-base/README.md b/halo2-base/README.md new file mode 100644 index 00000000..6b078ab9 --- /dev/null +++ b/halo2-base/README.md @@ -0,0 +1,590 @@ +# Halo2-base + +Halo2-base provides a streamlined frontend for interacting with the Halo2 API. It simplifies circuit programming to declaring constraints over a single advice and selector column and provides built-in circuit configuration and parellel proving and witness generation. + +Programmed circuit constraints are stored in `GateThreadBuilder` as a `Vec` of `Context`'s. Each `Context` can be interpreted as a "virtual column" which tracks witness values and constraints but does not assign them as cells within the Halo2 backend. Conceptually, one can think that at circuit generation time, the virtual columns are all concatenated into a **single** virtual column. This virtual column is then re-distributed into the minimal number of true `Column`s (aka Plonkish arithmetization columns) to fit within a user-specified number of rows. These true columns are then assigned into the Plonkish arithemization using the vanilla Halo2 backend. This has several benefits: + +- The user only needs to specify the desired number of rows. The rest of the circuit configuration process is done automatically because the optimal number of columns in the circuit can be calculated from the total number of cells in the `Context`s. This eliminates the need to manually assign circuit parameters at circuit creation time. +- In addition, this simplifies the process of testing the performance of different circuit configurations (different Plonkish arithmetization shapes) in the Halo2 backend, since the same virtual columns in the `Context` can be re-distributed into different Plonkish arithmetization tables. + +A user can also parallelize witness generation by specifying a function and a `Vec` of inputs to perform in parallel using `parallelize_in()` which creates a separate `Context` for each input that performs the specified function. These "virtual columns" are then computed in parallel during witness generation and combined back into a single column "virtual column" before cell assignment in the Halo2 backend. + +All assigned values in a circuit are assigned in the Halo2 backend by calling `synthesize()` in `GateCircuitBuilder` (or [`RangeCircuitBuilder`](#rangecircuitbuilder)) which in turn invokes `assign_all()` (or `assign_threads_in` if only doing witness generation) in `GateThreadBuilder` to assign the witness values tracked in a `Context` to their respective `Column` in the circuit within the Halo2 backend. + +Halo2-base also provides pre-built [Chips](https://zcash.github.io/halo2/concepts/chips.html) for common arithmetic operations in `GateChip` and range check arguments in `RangeChip`. Our `Chip` implementations differ slightly from ZCash's `Chip` implementations. In Zcash, the `Chip` struct stores knowledge about the `Config` and custom gates used. In halo2-base a `Chip` stores only functions while the interaction with the circuit's `Config` is hidden and done in `GateCircuitBuilder`. + +The structure of halo2-base is outlined as follows: + +- `builder.rs`: Contains `GateThreadBuilder`, `GateCircuitBuilder`, and `RangeCircuitBuilder` which implement the logic to provide different arithmetization configurations with different performance tradeoffs in the Halo2 backend. +- `lib.rs`: Defines the `QuantumCell`, `ContextCell`, `AssignedValue`, and `Context` types which track assigned values within a circuit across multiple columns and provide a streamlined interface to assign witness values directly to the advice column. +- `utils.rs`: Contains `BigPrimeField` and `ScalerField` traits which represent field elements within Halo2 and provides methods to decompose field elements into `u64` limbs and convert between field elements and `BigUint`. +- `flex_gate.rs`: Contains the implementation of `GateChip` and the `GateInstructions` trait which provide functions for basic arithmetic operations within Halo2. +- `range.rs:`: Implements `RangeChip` and the `RangeInstructions` trait which provide functions for performing range check and other lookup argument operations. + +This readme compliments the in-line documentation of halo2-base, providing an overview of `builder.rs` and `lib.rs`. + +
+ +## [**Context**](src/lib.rs) + +`Context` holds all information of an execution trace (circuit and its witness values). `Context` represents a "virtual column" that stores unassigned constraint information in the Halo2 backend. Storing the circuit information in a `Context` rather than assigning it directly to the Halo2 backend allows for the pre-computation of circuit parameters and preserves the underlying circuit information allowing for its rearrangement into multiple columns for parallelization in the Halo2 backend. + +During `synthesize()`, the advice values of all `Context`s are concatenated into a single "virtual column" that is split into multiple true `Column`s at `break_points` each representing a different sub-section of the "virtual column". During circuit synthesis, all cells are assigned to Halo2 `AssignedCell`s in a single `Region` within Halo2's backend. + +For parallel witness generation, multiple `Context`s are created for each parallel operation. After parallel witness generation, these `Context`'s are combined to form a single "virtual column" as above. Note that while the witness generation can be multi-threaded, the ordering of the contents in each `Context`, and the order of the `Context`s themselves, must be deterministic. + +```rust ignore +pub struct Context { + + witness_gen_only: bool, + + pub context_id: usize, + + pub advice: Vec>, + + pub cells_to_lookup: Vec>, + + pub zero_cell: Option>, + + pub selector: Vec, + + pub advice_equality_constraints: Vec<(ContextCell, ContextCell)>, + + pub constant_equality_constraints: Vec<(F, ContextCell)>, +} +``` + +`witness_gen_only` is set to `true` if we only care about witness generation and not about circuit constraints, otherwise it is set to false. This should **not** be set to `true` during mock proving or **key generation**. When this flag is `true`, we perform certain optimizations that are only valid when we don't care about constraints or selectors. + +A `Context` holds all equality and constant constraints as a `Vec` of `ContextCell` tuples representing the positions of the two cells to constrain. `advice` and`selector` store the respective column values of the `Context`'s which may represent the entire advice and selector column or a sub-section of the advice and selector column during parellel witness generation. `cells_to_lookup` tracks `AssignedValue`'s of cells to be looked up in a global lookup table, specifically for range checks, shared among all `Context`'s'. + +### [**ContextCell**](./src/lib.rs): + +`ContextCell` is a pointer to a specific cell within a `Context` identified by the Context's `context_id` and the cell's relative `offset` from the first cell of the advice column of the `Context`. + +```rust ignore +#[derive(Clone, Copy, Debug)] +pub struct ContextCell { + /// Identifier of the [Context] that this cell belongs to. + pub context_id: usize, + /// Relative offset of the cell within this [Context] advice column. + pub offset: usize, +} +``` + +### [**AssignedValue**](./src/lib.rs): + +`AssignedValue` represents a specific `Assigned` value assigned to a specific cell within a `Context` of a circuit referenced by a `ContextCell`. + +```rust ignore +pub struct AssignedValue { + pub value: Assigned, + + pub cell: Option, +} +``` + +### [**Assigned**](./src/plonk/assigned.rs) + +`Assigned` is a wrapper enum for values assigned to a cell within a circuit which stores the value as a fraction and marks it for batched inversion using [Montgomery's trick](https://zcash.github.io/halo2/background/fields.html#montgomerys-trick). Performing batched inversion allows for the computation of the inverse of all marked values with a single inversion operation. + +```rust ignore +pub enum Assigned { + /// The field element zero. + Zero, + /// A value that does not require inversion to evaluate. + Trivial(F), + /// A value stored as a fraction to enable batch inversion. + Rational(F, F), +} +``` + +
+ +## [**QuantumCell**](./src/lib.rs) + +`QuantumCell` is a helper enum that abstracts the scenarios in which a value is assigned to the advice column in Halo2-base. Without `QuantumCell` assigning existing or constant values to the advice column requires manually specifying the enforced constraints on top of assigning the value leading to bloated code. `QuantumCell` handles these technical operations, all a developer needs to do is specify which enum option in `QuantumCell` the value they are adding corresponds to. + +```rust ignore +pub enum QuantumCell { + + Existing(AssignedValue), + + Witness(F), + + WitnessFraction(Assigned), + + Constant(F), +} +``` + +QuantumCell contains the following enum variants. + +- **Existing**: + Assigns a value to the advice column that exists within the advice column. The value is an existing value from some previous part of your computation already in the advice column in the form of an `AssignedValue`. When you add an existing cell into the table a new cell will be assigned into the advice column with value equal to the existing value. An equality constraint will then be added between the new cell and the "existing" cell so the Verifier has a guarantee that these two cells are always equal. + + ```rust ignore + QuantumCell::Existing(acell) => { + self.advice.push(acell.value); + + if !self.witness_gen_only { + let new_cell = + ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; + self.advice_equality_constraints.push((new_cell, acell.cell.unwrap())); + } + } + ``` + +- **Witness**: + Assigns an entirely new witness value into the advice column, such as a private input. When `assign_cell()` is called the value is wrapped in as an `Assigned::Trivial()` which marks it for exclusion from batch inversion. + ```rust ignore + QuantumCell::Witness(val) => { + self.advice.push(Assigned::Trivial(val)); + } + ``` +- **WitnessFraction**: + Assigns an entirely new witness value to the advice column. `WitnessFraction` exists for optimization purposes and accepts Assigned values wrapped in `Assigned::Rational()` marked for batch inverion. + ```rust ignore + QuantumCell::WitnessFraction(val) => { + self.advice.push(val); + } + ``` +- **Constant**: + A value that is a "known" constant. A "known" refers to known at circuit creation time to both the Prover and Verifier. When you assign a constant value there exists another secret "Fixed" column in the circuit constraint table whose values are fixed at circuit creation time. When you assign a Constant value, you are adding this value to the Fixed column, adding the value as a witness to the Advice column, and then imposing an equality constraint between the two corresponding cells in the Fixed and Advice columns. + +```rust ignore +QuantumCell::Constant(c) => { + self.advice.push(Assigned::Trivial(c)); + // If witness generation is not performed, enforce equality constraints between the existing cell and the new cell + if !self.witness_gen_only { + let new_cell = + ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; + self.constant_equality_constraints.push((c, new_cell)); + } +} +``` + +
+ +## [**GateThreadBuilder**](./src/gates/builder.rs) & [**GateCircuitBuilder**](./src/gates/builder.rs) + +`GateThreadBuilder` tracks the cell assignments of a circuit as an array of `Vec` of `Context`' where `threads[i]` contains all `Context`'s for phase `i`. Each array element corresponds to a distinct challenge phase of Halo2's proving system, each of which has its own unique set of rows and columns. + +```rust ignore +#[derive(Clone, Debug, Default)] +pub struct GateThreadBuilder { + /// Threads for each challenge phase + pub threads: [Vec>; MAX_PHASE], + /// Max number of threads + thread_count: usize, + /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. + witness_gen_only: bool, + /// The `unknown` flag is used during key generation. If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + use_unknown: bool, +} +``` + +Once a `GateThreadBuilder` is created, gates may be assigned to a `Context` (or in the case of parallel witness generation multiple `Context`'s) within `threads`. Once the circuit is written `config()` is called to pre-compute the circuits size and set the circuit's environment variables. + +[**config()**](./src/gates/builder.rs) + +```rust ignore +pub fn config(&self, k: usize, minimum_rows: Option) -> FlexGateConfigParams { + let max_rows = (1 << k) - minimum_rows.unwrap_or(0); + let total_advice_per_phase = self + .threads + .iter() + .map(|threads| threads.iter().map(|ctx| ctx.advice.len()).sum::()) + .collect::>(); + // we do a rough estimate by taking ceil(advice_cells_per_phase / 2^k ) + // if this is too small, manual configuration will be needed + let num_advice_per_phase = total_advice_per_phase + .iter() + .map(|count| (count + max_rows - 1) / max_rows) + .collect::>(); + + let total_lookup_advice_per_phase = self + .threads + .iter() + .map(|threads| threads.iter().map(|ctx| ctx.cells_to_lookup.len()).sum::()) + .collect::>(); + let num_lookup_advice_per_phase = total_lookup_advice_per_phase + .iter() + .map(|count| (count + max_rows - 1) / max_rows) + .collect::>(); + + let total_fixed: usize = HashSet::::from_iter(self.threads.iter().flat_map(|threads| { + threads.iter().flat_map(|ctx| ctx.constant_equality_constraints.iter().map(|(c, _)| *c)) + })) + .len(); + let num_fixed = (total_fixed + (1 << k) - 1) >> k; + + let params = FlexGateConfigParams { + strategy: GateStrategy::Vertical, + num_advice_per_phase, + num_lookup_advice_per_phase, + num_fixed, + k, + }; + #[cfg(feature = "display")] + { + for phase in 0..MAX_PHASE { + if total_advice_per_phase[phase] != 0 || total_lookup_advice_per_phase[phase] != 0 { + println!( + "Gate Chip | Phase {}: {} advice cells , {} lookup advice cells", + phase, total_advice_per_phase[phase], total_lookup_advice_per_phase[phase], + ); + } + } + println!("Total {total_fixed} fixed cells"); + println!("Auto-calculated config params:\n {params:#?}"); + } + std::env::set_var("FLEX_GATE_CONFIG_PARAMS", serde_json::to_string(¶ms).unwrap()); + params +} +``` + +For circuit creation a `GateCircuitBuilder` is created by passing the `GateThreadBuilder` as an argument to `GateCircuitBuilder`'s `keygen`,`mock`, or `prover` functions. `GateCircuitBuilder` acts as a middleman between `GateThreadBuilder` and the Halo2 backend by implementing Halo2's`Circuit` Trait and calling into `GateThreadBuilder` `assign_all()` and `assign_threads_in()` functions to perform circuit assignment. + +**Note for developers:** We encourage you to always use [`RangeCircuitBuilder`](#rangecircuitbuilder) instead of `GateCircuitBuilder`: the former is smart enough to know to not create a lookup table if no cells are marked for lookup, so `RangeCircuitBuilder` is a strict generalization of `GateCircuitBuilder`. + +```rust ignore +/// Vector of vectors tracking the thread break points across different halo2 phases +pub type MultiPhaseThreadBreakPoints = Vec; + +#[derive(Clone, Debug)] +pub struct GateCircuitBuilder { + /// The Thread Builder for the circuit + pub builder: RefCell>, + /// Break points for threads within the circuit + pub break_points: RefCell, +} + +impl Circuit for GateCircuitBuilder { + type Config = FlexGateConfig; + type FloorPlanner = SimpleFloorPlanner; + + /// Creates a new instance of the circuit without withnesses filled in. + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + /// Configures a new circuit using the the parameters specified [Config]. + fn configure(meta: &mut ConstraintSystem) -> FlexGateConfig { + let FlexGateConfigParams { + strategy, + num_advice_per_phase, + num_lookup_advice_per_phase: _, + num_fixed, + k, + } = serde_json::from_str(&std::env::var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); + FlexGateConfig::configure(meta, strategy, &num_advice_per_phase, num_fixed, k) + } + + /// Performs the actual computation on the circuit (e.g., witness generation), filling in all the advice values for a particular proof. + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + self.sub_synthesize(&config, &[], &[], &mut layouter); + Ok(()) + } +} +``` + +During circuit creation `synthesize()` is invoked which passes into `sub_synthesize()` a `FlexGateConfig` containing the actual circuits columns and a mutable reference to a `Layouter` from the Halo2 API which facilitates the final assignment of cells within a `Region` of a circuit in Halo2's backend. + +`GateCircuitBuilder` contains a list of breakpoints for each thread across all phases in and `GateThreadBuilder` itself. Both are wrapped in a `RefCell` allowing them to be borrowed mutably so the function performing circuit creation can take ownership of the `builder` and `break_points` can be recorded during circuit creation for later use. + +[**sub_synthesize()**](./src/gates/builder.rs) + +```rust ignore + pub fn sub_synthesize( + &self, + gate: &FlexGateConfig, + lookup_advice: &[Vec>], + q_lookup: &[Option], + layouter: &mut impl Layouter, + ) -> HashMap<(usize, usize), (circuit::Cell, usize)> { + let mut first_pass = SKIP_FIRST_PASS; + let mut assigned_advices = HashMap::new(); + layouter + .assign_region( + || "GateCircuitBuilder generated circuit", + |mut region| { + if first_pass { + first_pass = false; + return Ok(()); + } + // only support FirstPhase in this Builder because getting challenge value requires more specialized witness generation during synthesize + // If we are not performing witness generation only, we can skip the first pass and assign threads directly + if !self.builder.borrow().witness_gen_only { + // clone the builder so we can re-use the circuit for both vk and pk gen + let builder = self.builder.borrow().clone(); + for threads in builder.threads.iter().skip(1) { + assert!( + threads.is_empty(), + "GateCircuitBuilder only supports FirstPhase for now" + ); + } + let assignments = builder.assign_all( + gate, + lookup_advice, + q_lookup, + &mut region, + Default::default(), + ); + *self.break_points.borrow_mut() = assignments.break_points; + assigned_advices = assignments.assigned_advices; + } else { + // If we are only generating witness, we can skip the first pass and assign threads directly + let builder = self.builder.take(); + let break_points = self.break_points.take(); + for (phase, (threads, break_points)) in builder + .threads + .into_iter() + .zip(break_points.into_iter()) + .enumerate() + .take(1) + { + assign_threads_in( + phase, + threads, + gate, + lookup_advice.get(phase).unwrap_or(&vec![]), + &mut region, + break_points, + ); + } + } + Ok(()) + }, + ) + .unwrap(); + assigned_advices + } +``` + +Within `sub_synthesize()` `layouter`'s `assign_region()` function is invoked which yields a mutable reference to `Region`. `region` is used to assign cells within a contiguous region of the circuit represented in Halo2's proving system. + +If `witness_gen_only` is not set within the `builder` (for keygen, and mock proving) `sub_synthesize` takes ownership of the `builder`, and calls `assign_all()` to assign all cells within this context to a circuit in Halo2's backend. The resulting column breakpoints are recorded in `GateCircuitBuilder`'s `break_points` field. + +`assign_all()` iterates over each `Context` within a `phase` and assigns the values and constraints of the advice, selector, fixed, and lookup columns to the circuit using `region`. + +Breakpoints for the advice column are assigned sequentially. If, the `row_offset` of the cell value being currently assigned exceeds the maximum amount of rows allowed in a column a new column is created. + +It should be noted this process is only compatible with the first phase of Halo2's proving system as retrieving witness challenges in later phases requires more specialized witness generation during synthesis. Therefore, `assign_all()` must assert all elements in `threads` are unassigned excluding the first phase. + +[**assign_all()**](./src/gates/builder.rs) + +```rust ignore +pub fn assign_all( + &self, + config: &FlexGateConfig, + lookup_advice: &[Vec>], + q_lookup: &[Option], + region: &mut Region, + KeygenAssignments { + mut assigned_advices, + mut assigned_constants, + mut break_points + }: KeygenAssignments, + ) -> KeygenAssignments { + ... + for (phase, threads) in self.threads.iter().enumerate() { + let mut break_point = vec![]; + let mut gate_index = 0; + let mut row_offset = 0; + for ctx in threads { + let mut basic_gate = config.basic_gates[phase] + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + assert_eq!(ctx.selector.len(), ctx.advice.len()); + + for (i, (advice, &q)) in ctx.advice.iter().zip(ctx.selector.iter()).enumerate() { + let column = basic_gate.value; + let value = if use_unknown { Value::unknown() } else { Value::known(advice) }; + #[cfg(feature = "halo2-axiom")] + let cell = *region.assign_advice(column, row_offset, value).cell(); + #[cfg(not(feature = "halo2-axiom"))] + let cell = region + .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) + .unwrap() + .cell(); + assigned_advices.insert((ctx.context_id, i), (cell, row_offset)); + ... + +``` + +In the case a breakpoint falls on the overlap between two gates (such as chained addition of two cells) the cells the breakpoint falls on must be copied to the next column and a new equality constraint enforced between the value of the cell in the old column and the copied cell in the new column. This prevents the circuit from being undersconstratined and preserves the equality constraint from the overlapping gates. + +```rust ignore +if (q && row_offset + 4 > max_rows) || row_offset >= max_rows - 1 { + break_point.push(row_offset); + row_offset = 0; + gate_index += 1; + +// when there is a break point, because we may have two gates that overlap at the current cell, we must copy the current cell to the next column for safety + basic_gate = config.basic_gates[phase] + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + let column = basic_gate.value; + + #[cfg(feature = "halo2-axiom")] + { + let ncell = region.assign_advice(column, row_offset, value); + region.constrain_equal(ncell.cell(), &cell); + } + #[cfg(not(feature = "halo2-axiom"))] + { + let ncell = region + .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) + .unwrap() + .cell(); + region.constrain_equal(ncell, cell).unwrap(); + } +} + +``` + +If `witness_gen_only` is set, only witness generation is performed, and no copy constraints or selector values are considered. + +Witness generation can be parallelized by a user by calling `parallelize_in()` and specifying a function and a `Vec` of inputs to perform in parallel. `parallelize_in()` creates a separate `Context` for each input that performs the specified function and appends them to the `Vec` of `Context`'s of a particular phase. + +[**assign_threads_in()**](./src/gates/builder.rs) + +```rust ignore +pub fn assign_threads_in( + phase: usize, + threads: Vec>, + config: &FlexGateConfig, + lookup_advice: &[Column], + region: &mut Region, + break_points: ThreadBreakPoints, +) { + if config.basic_gates[phase].is_empty() { + assert!(threads.is_empty(), "Trying to assign threads in a phase with no columns"); + return; + } + + let mut break_points = break_points.into_iter(); + let mut break_point = break_points.next(); + + let mut gate_index = 0; + let mut column = config.basic_gates[phase][gate_index].value; + let mut row_offset = 0; + + let mut lookup_offset = 0; + let mut lookup_advice = lookup_advice.iter(); + let mut lookup_column = lookup_advice.next(); + for ctx in threads { + // if lookup_column is [None], that means there should be a single advice column and it has lookup enabled, so we don't need to copy to special lookup advice columns + if lookup_column.is_some() { + for advice in ctx.cells_to_lookup { + if lookup_offset >= config.max_rows { + lookup_offset = 0; + lookup_column = lookup_advice.next(); + } + // Assign the lookup advice values to the lookup_column + let value = advice.value; + let lookup_column = *lookup_column.unwrap(); + #[cfg(feature = "halo2-axiom")] + region.assign_advice(lookup_column, lookup_offset, Value::known(value)); + #[cfg(not(feature = "halo2-axiom"))] + region + .assign_advice(|| "", lookup_column, lookup_offset, || Value::known(value)) + .unwrap(); + + lookup_offset += 1; + } + } + // Assign advice values to the advice columns in each [Context] + for advice in ctx.advice { + #[cfg(feature = "halo2-axiom")] + region.assign_advice(column, row_offset, Value::known(advice)); + #[cfg(not(feature = "halo2-axiom"))] + region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); + + if break_point == Some(row_offset) { + break_point = break_points.next(); + row_offset = 0; + gate_index += 1; + column = config.basic_gates[phase][gate_index].value; + + #[cfg(feature = "halo2-axiom")] + region.assign_advice(column, row_offset, Value::known(advice)); + #[cfg(not(feature = "halo2-axiom"))] + region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); + } + + row_offset += 1; + } + } + +``` + +`sub_synthesize` iterates over all phases and calls `assign_threads_in()` for that phase. `assign_threads_in()` iterates over all `Context`s within that phase and assigns all lookup and advice values in the `Context`, creating a new advice column at every pre-computed "breakpoint" by incrementing `gate_index` and assigning `column` to a new `Column` found at `config.basic_gates[phase][gate_index].value`. + +## [**RangeCircuitBuilder**](./src/gates/builder.rs) + +`RangeCircuitBuilder` is a wrapper struct around `GateCircuitBuilder`. Like `GateCircuitBuilder` it acts as a middleman between `GateThreadBuilder` and the Halo2 backend by implementing Halo2's `Circuit` Trait. + +```rust ignore +#[derive(Clone, Debug)] +pub struct RangeCircuitBuilder(pub GateCircuitBuilder); + +impl Circuit for RangeCircuitBuilder { + type Config = RangeConfig; + type FloorPlanner = SimpleFloorPlanner; + + /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + /// Configures a new circuit using the the parameters specified [Config] and environment variable `LOOKUP_BITS`. + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let FlexGateConfigParams { + strategy, + num_advice_per_phase, + num_lookup_advice_per_phase, + num_fixed, + k, + } = serde_json::from_str(&var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); + let strategy = match strategy { + GateStrategy::Vertical => RangeStrategy::Vertical, + }; + let lookup_bits = var("LOOKUP_BITS").unwrap_or_else(|_| "0".to_string()).parse().unwrap(); + RangeConfig::configure( + meta, + strategy, + &num_advice_per_phase, + &num_lookup_advice_per_phase, + num_fixed, + lookup_bits, + k, + ) + } + + /// Performs the actual computation on the circuit (e.g., witness generation), populating the lookup table and filling in all the advice values for a particular proof. + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // only load lookup table if we are actually doing lookups + if config.lookup_advice.iter().map(|a| a.len()).sum::() != 0 + || !config.q_lookup.iter().all(|q| q.is_none()) + { + config.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); + } + self.0.sub_synthesize(&config.gate, &config.lookup_advice, &config.q_lookup, &mut layouter); + Ok(()) + } +} +``` + +`RangeCircuitBuilder` differs from `GateCircuitBuilder` in that it contains a `RangeConfig` instead of a `FlexGateConfig` as its `Config`. `RangeConfig` contains a `lookup` table needed to declare lookup arguments within Halo2's backend. When creating a circuit that uses lookup tables `GateThreadBuilder` must be wrapped with `RangeCircuitBuilder` instead of `GateCircuitBuilder` otherwise circuit synthesis will fail as a lookup table is not present within the Halo2 backend. + +**Note:** We encourage you to always use `RangeCircuitBuilder` instead of `GateCircuitBuilder`: the former is smart enough to know to not create a lookup table if no cells are marked for lookup, so `RangeCircuitBuilder` is a strict generalization of `GateCircuitBuilder`. diff --git a/halo2-base/proptest-regressions/gates/tests/prop_test.txt b/halo2-base/proptest-regressions/gates/tests/prop_test.txt new file mode 100644 index 00000000..aa4e1000 --- /dev/null +++ b/halo2-base/proptest-regressions/gates/tests/prop_test.txt @@ -0,0 +1,11 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 8489bbcc3439950355c90ecbc92546a66e4b57eae0a3856e7a4ccb59bf74b4ce # shrinks to k = 1, len = 1, idx = 0, witness_vals = [0x0000000000000000000000000000000000000000000000000000000000000000] +cc b18c4f5e502fe36dbc2471f89a6ffb389beaf473b280e844936298ab1cf9b74e # shrinks to (k, len, idx, witness_vals) = (8, 2, 1, [0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000001]) +cc 4528fb02e7227f85116c2a16aef251b9c3b6d9c340ddb50b936c2140d7856cc4 # shrinks to inputs = ([], []) +cc 79bfe42c93b5962a38b2f831f1dd438d8381a24a6ce15bfb89a8562ce9af0a2d # shrinks to (k, len, idx, witness_vals) = (8, 62, 0, [0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000]) +cc d0e10a06108cb58995a8ae77a91b299fb6230e9e6220121c48f2488e5d199e82 # shrinks to input = (0x000000000000000000000000000000000000000000000000070a95cb0607bef9, 4096) diff --git a/halo2-base/src/gates/builder.rs b/halo2-base/src/gates/builder.rs index 9cd68db0..22c2ce93 100644 --- a/halo2-base/src/gates/builder.rs +++ b/halo2-base/src/gates/builder.rs @@ -17,6 +17,9 @@ use std::{ env::{set_var, var}, }; +mod parallelize; +pub use parallelize::*; + /// Vector of thread advice column break points pub type ThreadBreakPoints = Vec; /// Vector of vectors tracking the thread break points across different halo2 phases diff --git a/halo2-base/src/gates/builder/parallelize.rs b/halo2-base/src/gates/builder/parallelize.rs new file mode 100644 index 00000000..ab9171d5 --- /dev/null +++ b/halo2-base/src/gates/builder/parallelize.rs @@ -0,0 +1,38 @@ +use itertools::Itertools; +use rayon::prelude::*; + +use crate::{utils::ScalarField, Context}; + +use super::GateThreadBuilder; + +/// Utility function to parallelize an operation involving [`Context`]s in phase `phase`. +pub fn parallelize_in( + phase: usize, + builder: &mut GateThreadBuilder, + input: Vec, + f: FR, +) -> Vec +where + F: ScalarField, + T: Send, + R: Send, + FR: Fn(&mut Context, T) -> R + Send + Sync, +{ + let witness_gen_only = builder.witness_gen_only(); + // to prevent concurrency issues with context id, we generate all the ids first + let ctx_ids = input.iter().map(|_| builder.get_new_thread_id()).collect_vec(); + let (outputs, mut ctxs): (Vec<_>, Vec<_>) = input + .into_par_iter() + .zip(ctx_ids.into_par_iter()) + .map(|(input, ctx_id)| { + // create new context + let mut ctx = Context::new(witness_gen_only, ctx_id); + let output = f(&mut ctx, input); + (output, ctx) + }) + .unzip(); + // we collect the new threads to ensure they are a FIXED order, otherwise later `assign_threads_in` will get confused + builder.threads[phase].append(&mut ctxs); + + outputs +} diff --git a/halo2-base/src/gates/flex_gate.rs b/halo2-base/src/gates/flex_gate.rs index ee4ebb69..1907521e 100644 --- a/halo2-base/src/gates/flex_gate.rs +++ b/halo2-base/src/gates/flex_gate.rs @@ -74,7 +74,7 @@ impl BasicGateConfig { /// Wrapper for [ConstraintSystem].create_gate(name, meta) creates a gate form [q * (a + b * c - out)]. /// * `meta`: [ConstraintSystem] used for the gate fn create_gate(&self, meta: &mut ConstraintSystem) { - meta.create_gate("1 column a * b + c = out", |meta| { + meta.create_gate("1 column a + b * c = out", |meta| { let q = meta.query_selector(self.q_enable); let a = meta.query_advice(self.value, Rotation::cur()); @@ -558,12 +558,16 @@ pub trait GateInstructions { /// Constrains and returns an indicator vector from a slice of boolean values, where `output[idx] = 1` iff idx = (the number represented by `bits` in binary little endian), otherwise `output[idx] = 0`. /// * `ctx`: [Context] to add the constraints to /// * `bits`: slice of [QuantumCell]'s that contains boolean values + /// + /// # Assumptions + /// * `bits` is non-empty fn bits_to_indicator( &self, ctx: &mut Context, bits: &[AssignedValue], ) -> Vec> { let k = bits.len(); + assert!(k > 0, "bits_to_indicator: bits must be non-empty"); // (inv_last_bit, last_bit) = (1, 0) if bits[k - 1] = 0 let (inv_last_bit, last_bit) = { @@ -759,19 +763,23 @@ pub trait GateInstructions { /// Performs and constrains Lagrange interpolation on `coords` and evaluates the resulting polynomial at `x`. /// - /// Given pairs `coords[i] = (x_i, y_i)`, let `f` be the unique degree `len(coords)` polynomial such that `f(x_i) = y_i` for all `i`. + /// Given pairs `coords[i] = (x_i, y_i)`, let `f` be the unique degree `len(coords) - 1` polynomial such that `f(x_i) = y_i` for all `i`. /// /// Returns: /// (f(x), Prod_i(x - x_i)) /// * `ctx`: [Context] to add the constraints to /// * `coords`: immutable reference to a slice of tuples of [AssignedValue]s representing the points to interpolate over such that `coords[i] = (x_i, y_i)` /// * `x`: x-coordinate of the point to evaluate `f` at + /// + /// # Assumptions + /// * `coords` is non-empty fn lagrange_and_eval( &self, ctx: &mut Context, coords: &[(AssignedValue, AssignedValue)], x: AssignedValue, ) -> (AssignedValue, AssignedValue) { + assert!(!coords.is_empty(), "coords should not be empty"); let mut z = self.sub(ctx, Existing(x), Existing(coords[0].0)); for coord in coords.iter().skip(1) { let sub = self.sub(ctx, Existing(x), Existing(coord.0)); @@ -1100,20 +1108,14 @@ impl GateInstructions for GateChip { /// /// Assumes `range_bits >= number of bits in a`. /// * `a`: [QuantumCell] of the value to convert - /// * `range_bits`: range of bits needed to represent `a` + /// * `range_bits`: range of bits needed to represent `a`. Assumes `range_bits > 0`. fn num_to_bits( &self, ctx: &mut Context, a: AssignedValue, range_bits: usize, ) -> Vec> { - let a_bytes = a.value().to_repr(); - let bits = a_bytes - .as_ref() - .iter() - .flat_map(|byte| (0..8u32).map(|i| (*byte as u64 >> i) & 1)) - .map(|x| Witness(F::from(x))) - .take(range_bits); + let bits = a.value().to_u64_limbs(range_bits, 1).into_iter().map(|x| Witness(F::from(x))); let mut bit_cells = Vec::with_capacity(range_bits); let row_offset = ctx.advice.len(); diff --git a/halo2-base/src/gates/range.rs b/halo2-base/src/gates/range.rs index 20ebd57b..7a6b6173 100644 --- a/halo2-base/src/gates/range.rs +++ b/halo2-base/src/gates/range.rs @@ -259,7 +259,7 @@ pub trait RangeInstructions { num_bits: usize, ) -> AssignedValue; - /// Performs a range check that `a` has at most `bit_length(b)` and then constrains that `a` is in `[0,b)`. + /// Performs a range check that `a` has at most `ceil(bit_length(b) / lookup_bits) * lookup_bits` and then constrains that `a` is in `[0,b)`. /// /// Returns 1 if `a` < `b`, otherwise 0. /// @@ -278,12 +278,14 @@ pub trait RangeInstructions { self.is_less_than(ctx, a, Constant(self.gate().get_field_element(b)), range_bits) } - /// Performs a range check that `a` has at most `bit_length(b)` and then constrains that `a` is in `[0,b)`. + /// Performs a range check that `a` has at most `ceil(b.bits() / lookup_bits) * lookup_bits` bits and then constrains that `a` is in `[0,b)`. /// /// Returns 1 if `a` < `b`, otherwise 0. /// /// * a: [AssignedValue] value to check /// * b: upper bound as [BigUint] value + /// + /// For the current implementation using [`is_less_than`], we require `ceil(b.bits() / lookup_bits) + 1 < F::NUM_BITS / lookup_bits` fn is_big_less_than_safe( &self, ctx: &mut Context, @@ -411,18 +413,16 @@ pub trait RangeInstructions { a: AssignedValue, limb_bits: usize, ) -> AssignedValue { - let a_v = a.value(); - let bit_v = { - let a = a_v.get_lower_32(); - F::from(a ^ 1 != 0) - }; + let a_big = fe_to_biguint(a.value()); + let bit_v = F::from(a_big.bit(0)); let two = self.gate().get_field_element(2u64); - let h_v = (*a_v - bit_v) * two.invert().unwrap(); - ctx.assign_region([Witness(bit_v), Witness(h_v), Constant(two), Existing(a)], [0]); + let h_v = F::from_bytes_le(&(a_big >> 1usize).to_bytes_le()); + ctx.assign_region([Witness(bit_v), Witness(h_v), Constant(two), Existing(a)], [0]); let half = ctx.get(-3); - self.range_check(ctx, half, limb_bits - 1); let bit = ctx.get(-4); + + self.range_check(ctx, half, limb_bits - 1); self.gate().assert_bit(ctx, bit); bit } @@ -441,8 +441,8 @@ pub struct RangeChip { pub gate: GateChip, /// Defines the number of bits represented in the lookup table [0,2lookup_bits). pub lookup_bits: usize, - /// [Vec] of 'limbs' represented as [QuantumCell] that divide the underlying scalar field element into sections smaller than lookup_bits. - /// * This allows range checks on field elements that are larger than the maximum value of the lookup table. + /// [Vec] of powers of `2 ** lookup_bits` represented as [QuantumCell::Constant]. + /// These are precomputed and cached as a performance optimization for later limb decompositions. We precompute up to the higher power that fits in `F`, which is `2 ** ((F::CAPACITY / lookup_bits) * lookup_bits)`. pub limb_bases: Vec>, } @@ -453,7 +453,7 @@ impl RangeChip { pub fn new(strategy: RangeStrategy, lookup_bits: usize) -> Self { let limb_base = F::from(1u64 << lookup_bits); let mut running_base = limb_base; - let num_bases = F::NUM_BITS as usize / lookup_bits; + let num_bases = F::CAPACITY as usize / lookup_bits; let mut limb_bases = Vec::with_capacity(num_bases + 1); limb_bases.extend([Constant(F::one()), Constant(running_base)]); for _ in 2..=num_bases { @@ -494,13 +494,16 @@ impl RangeInstructions for RangeChip { /// Checks and constrains that `a` lies in the range [0, 2range_bits). /// - /// This is done by decomposing `a` into `k` limbs, where `k = (range_bits + lookup_bits - 1) / lookup_bits`. + /// This is done by decomposing `a` into `k` limbs, where `k = ceil(range_bits / lookup_bits)`. /// Each limb is constrained to be within the range [0, 2lookup_bits). /// The limbs are then combined to form `a` again with the last limb having `rem_bits` number of bits. /// /// * `a`: [AssignedValue] value to be range checked /// * `range_bits`: number of bits in the range /// * `lookup_bits`: number of bits in the lookup table + /// + /// # Assumptions + /// * `ceil(range_bits / lookup_bits) * lookup_bits <= F::CAPACITY` fn range_check(&self, ctx: &mut Context, a: AssignedValue, range_bits: usize) { // the number of limbs let k = (range_bits + self.lookup_bits - 1) / self.lookup_bits; @@ -585,10 +588,13 @@ impl RangeInstructions for RangeChip { /// Constrains whether `a` is in `[0, b)`, and returns 1 if `a` < `b`, otherwise 0. /// - /// Assumes that`a` and `b` are known to have <= num_bits bits. /// * a: first [QuantumCell] to compare /// * b: second [QuantumCell] to compare /// * num_bits: number of bits to represent the values + /// + /// # Assumptions + /// * `a` and `b` are known to have `<= num_bits` bits. + /// * (`ceil(num_bits / lookup_bits) + 1) * lookup_bits <= F::CAPACITY` fn is_less_than( &self, ctx: &mut Context, @@ -601,6 +607,10 @@ impl RangeInstructions for RangeChip { let k = (num_bits + self.lookup_bits - 1) / self.lookup_bits; let padded_bits = k * self.lookup_bits; + debug_assert!( + padded_bits + self.lookup_bits <= F::CAPACITY as usize, + "num_bits is too large for this is_less_than implementation" + ); let pow_padded = self.gate.pow_of_two[padded_bits]; let shift_a_val = pow_padded + a.value(); diff --git a/halo2-base/src/gates/tests/README.md b/halo2-base/src/gates/tests/README.md new file mode 100644 index 00000000..24f34537 --- /dev/null +++ b/halo2-base/src/gates/tests/README.md @@ -0,0 +1,9 @@ +# Tests + +For tests that use `GateCircuitBuilder` or `RangeCircuitBuilder`, we currently must use environmental variables `FLEX_GATE_CONFIG` and `LOOKUP_BITS` to pass circuit configuration parameters to the `Circuit::configure` function. This is troublesome when Rust executes tests in parallel, so we to make sure all tests pass, run + +``` +cargo test -- --test-threads=1 +``` + +to force serial execution. diff --git a/halo2-base/src/gates/tests/flex_gate_tests.rs b/halo2-base/src/gates/tests/flex_gate_tests.rs new file mode 100644 index 00000000..b6d3e5ec --- /dev/null +++ b/halo2-base/src/gates/tests/flex_gate_tests.rs @@ -0,0 +1,266 @@ +use super::*; +use crate::halo2_proofs::dev::MockProver; +use crate::halo2_proofs::dev::VerifyFailure; +use crate::utils::ScalarField; +use crate::QuantumCell::Witness; +use crate::{ + gates::{ + builder::{GateCircuitBuilder, GateThreadBuilder}, + flex_gate::{GateChip, GateInstructions}, + }, + QuantumCell, +}; +use test_case::test_case; + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "add(): 1 + 1 == 2")] +pub fn test_add(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.add(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "sub(): 1 - 1 == 0")] +pub fn test_sub(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.sub(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(Witness(Fr::from(1)) => -Fr::from(1) ; "neg(): 1 -> -1")] +pub fn test_neg(a: QuantumCell) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.neg(ctx, a); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "mul(): 1 * 1 == 1")] +pub fn test_mul(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.mul(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "mul_add(): 1 * 1 + 1 == 2")] +pub fn test_mul_add(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.mul_add(ctx, inputs[0], inputs[1], inputs[2]); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "mul_not(): 1 * 1 == 0")] +pub fn test_mul_not(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.mul_not(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(Fr::from(1) => Ok(()); "assert_bit(): 1 == bit")] +pub fn test_assert_bit(input: F) -> Result<(), Vec> { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = ctx.assign_witnesses([input])[0]; + chip.assert_bit(ctx, a); + // auto-tune circuit + builder.config(6, Some(9)); + // create circuit + let circuit = GateCircuitBuilder::mock(builder); + MockProver::run(6, &circuit, vec![]).unwrap().verify() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "div_unsafe(): 1 / 1 == 1")] +pub fn test_div_unsafe(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.div_unsafe(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from); "assert_is_const()")] +pub fn test_assert_is_const(inputs: &[F]) { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = ctx.assign_witnesses([inputs[0]])[0]; + chip.assert_is_const(ctx, &a, &inputs[1]); + // auto-tune circuit + builder.config(6, Some(9)); + // create circuit + let circuit = GateCircuitBuilder::mock(builder); + MockProver::run(6, &circuit, vec![]).unwrap().assert_satisfied() +} + +#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => Fr::from(5) ; "inner_product(): 1 * 1 + ... + 1 * 1 == 5")] +pub fn test_inner_product(input: (Vec>, Vec>)) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.inner_product(ctx, input.0, input.1); + *a.value() +} + +#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => (Fr::from(5), Fr::from(1)); "inner_product_left_last(): 1 * 1 + ... + 1 * 1 == (5, 1)")] +pub fn test_inner_product_left_last( + input: (Vec>, Vec>), +) -> (F, F) { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.inner_product_left_last(ctx, input.0, input.1); + (*a.0.value(), *a.1.value()) +} + +#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => vec![Fr::one(), Fr::from(2), Fr::from(3), Fr::from(4), Fr::from(5)]; "inner_product_with_sums(): 1 * 1 + ... + 1 * 1 == [1, 2, 3, 4, 5]")] +pub fn test_inner_product_with_sums( + input: (Vec>, Vec>), +) -> Vec { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.inner_product_with_sums(ctx, input.0, input.1); + a.into_iter().map(|x| *x.value()).collect() +} + +#[test_case((vec![(Fr::from(1), Witness(Fr::from(1)), Witness(Fr::from(1)))], Witness(Fr::from(1))) => Fr::from(2) ; "sum_product_with_coeff_and_var(): 1 * 1 + 1 == 2")] +pub fn test_sum_products_with_coeff_and_var( + input: (Vec<(F, QuantumCell, QuantumCell)>, QuantumCell), +) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.sum_products_with_coeff_and_var(ctx, input.0, input.1); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "and(): 1 && 1 == 1")] +pub fn test_and(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.and(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(Witness(Fr::from(1)) => Fr::zero() ; "not(): !1 == 0")] +pub fn test_not(a: QuantumCell) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.not(ctx, a); + *a.value() +} + +#[test_case(&[2, 3, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "select(): 2 ? 3 : 1 == 2")] +pub fn test_select(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.select(ctx, inputs[0], inputs[1], inputs[2]); + *a.value() +} + +#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "or_and(): 1 || 1 && 1 == 1")] +pub fn test_or_and(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.or_and(ctx, inputs[0], inputs[1], inputs[2]); + *a.value() +} + +#[test_case(Fr::zero() => vec![Fr::one(), Fr::zero()]; "bits_to_indicator(): 0 -> [1, 0]")] +pub fn test_bits_to_indicator(bits: F) -> Vec { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = ctx.assign_witnesses([bits])[0]; + let a = chip.bits_to_indicator(ctx, &[a]); + a.iter().map(|x| *x.value()).collect() +} + +#[test_case((Witness(Fr::zero()), 3) => vec![Fr::one(), Fr::zero(), Fr::zero()] ; "idx_to_indicator(): 0 -> [1, 0, 0]")] +pub fn test_idx_to_indicator(input: (QuantumCell, usize)) -> Vec { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.idx_to_indicator(ctx, input.0, input.1); + a.iter().map(|x| *x.value()).collect() +} + +#[test_case((vec![Witness(Fr::zero()), Witness(Fr::one()), Witness(Fr::from(2))], Witness(Fr::one())) => Fr::from(1) ; "select_by_indicator(): [0, 1, 2] -> 1")] +pub fn test_select_by_indicator(input: (Vec>, QuantumCell)) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.idx_to_indicator(ctx, input.1, input.0.len()); + let a = chip.select_by_indicator(ctx, input.0, a); + *a.value() +} + +#[test_case((vec![Witness(Fr::zero()), Witness(Fr::one()), Witness(Fr::from(2))], Witness(Fr::one())) => Fr::from(1) ; "select_from_idx(): [0, 1, 2] -> 1")] +pub fn test_select_from_idx(input: (Vec>, QuantumCell)) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.idx_to_indicator(ctx, input.1, input.0.len()); + let a = chip.select_by_indicator(ctx, input.0, a); + *a.value() +} + +#[test_case(Fr::zero() => Fr::from(1) ; "is_zero(): 0 -> 1")] +pub fn test_is_zero(x: F) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = ctx.assign_witnesses([x])[0]; + let a = chip.is_zero(ctx, a); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::one() ; "is_equal(): 1 == 1")] +pub fn test_is_equal(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.is_equal(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case((Fr::from(6u64), 3) => vec![Fr::zero(), Fr::one(), Fr::one()] ; "num_to_bits(): 6")] +pub fn test_num_to_bits(input: (F, usize)) -> Vec { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = ctx.assign_witnesses([input.0])[0]; + let a = chip.num_to_bits(ctx, a, input.1); + a.iter().map(|x| *x.value()).collect() +} + +#[test_case(&[0, 1, 2].map(Fr::from) => (Fr::one(), Fr::from(2)) ; "lagrange_eval(): constant fn")] +pub fn test_lagrange_eval(input: &[F]) -> (F, F) { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let input = ctx.assign_witnesses(input.iter().copied()); + let a = chip.lagrange_and_eval(ctx, &[(input[0], input[1])], input[2]); + (*a.0.value(), *a.1.value()) +} + +#[test_case(1 => Fr::one(); "inner_product_simple(): 1 -> 1")] +pub fn test_get_field_element(n: u64) -> F { + let chip = GateChip::default(); + chip.get_field_element(n) +} diff --git a/halo2-base/src/gates/tests/idx_to_indicator.rs b/halo2-base/src/gates/tests/idx_to_indicator.rs index 4520b67c..4db68e3e 100644 --- a/halo2-base/src/gates/tests/idx_to_indicator.rs +++ b/halo2-base/src/gates/tests/idx_to_indicator.rs @@ -110,5 +110,10 @@ fn test_idx_to_indicator() { test_idx_to_indicator_gen(8, 4); test_idx_to_indicator_gen(8, 10); test_idx_to_indicator_gen(8, 20); +} + +#[test] +#[ignore = "takes too long"] +fn test_idx_to_indicator_large() { test_idx_to_indicator_gen(11, 100); } diff --git a/halo2-base/src/gates/tests/mod.rs b/halo2-base/src/gates/tests/mod.rs index e7ebb386..a12adeba 100644 --- a/halo2-base/src/gates/tests/mod.rs +++ b/halo2-base/src/gates/tests/mod.rs @@ -1,3 +1,4 @@ +#![allow(clippy::type_complexity)] use crate::halo2_proofs::{ halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, @@ -12,10 +13,20 @@ use crate::halo2_proofs::{ }; use rand::rngs::OsRng; +#[cfg(test)] +mod flex_gate_tests; #[cfg(test)] mod general; #[cfg(test)] mod idx_to_indicator; +#[cfg(test)] +mod neg_prop_tests; +#[cfg(test)] +mod pos_prop_tests; +#[cfg(test)] +mod range_gate_tests; +#[cfg(test)] +mod test_ground_truths; /// helper function to generate a proof with real prover pub fn gen_proof( diff --git a/halo2-base/src/gates/tests/neg_prop_tests.rs b/halo2-base/src/gates/tests/neg_prop_tests.rs new file mode 100644 index 00000000..226a01f9 --- /dev/null +++ b/halo2-base/src/gates/tests/neg_prop_tests.rs @@ -0,0 +1,398 @@ +use std::env::set_var; + +use ff::Field; +use itertools::Itertools; +use num_bigint::BigUint; +use proptest::{collection::vec, prelude::*}; +use rand::rngs::OsRng; + +use crate::halo2_proofs::{ + dev::MockProver, + halo2curves::{bn256::Fr, FieldExt}, + plonk::Assigned, +}; +use crate::{ + gates::{ + builder::{GateCircuitBuilder, GateThreadBuilder, RangeCircuitBuilder}, + range::{RangeChip, RangeInstructions}, + tests::{ + pos_prop_tests::{rand_bin_witness, rand_fr, rand_witness}, + test_ground_truths, + }, + GateChip, GateInstructions, + }, + utils::{biguint_to_fe, bit_length, fe_to_biguint, ScalarField}, + QuantumCell, + QuantumCell::Witness, +}; + +// Strategies for generating random witnesses +prop_compose! { + // length == 1 is just selecting [0] which should be covered in unit test + fn idx_to_indicator_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, idx_val in prop::sample::select(vec![Fr::zero(), Fr::one(), Fr::random(OsRng)]), len in 2usize..=max_size) + (k in Just(k), idx in 0..len, idx_val in Just(idx_val), len in Just(len), mut witness_vals in arb_indicator::(len)) + -> (usize, usize, usize, Vec) { + witness_vals[idx] = idx_val; + (k, len, idx, witness_vals) + } +} + +prop_compose! { + fn select_strat(k_bounds: (usize, usize)) + (k in k_bounds.0..=k_bounds.1, a in rand_witness(), b in rand_witness(), sel in rand_bin_witness(), rand_output in rand_fr()) + -> (usize, QuantumCell, QuantumCell, QuantumCell, Fr) { + (k, a, b, sel, rand_output) + } +} + +prop_compose! { + fn select_by_indicator_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) + (k in Just(k), a in vec(rand_witness(), len), idx in 0..len, rand_output in rand_fr()) + -> (usize, Vec>, usize, Fr) { + (k, a, idx, rand_output) + } +} + +prop_compose! { + fn select_from_idx_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) + (k in Just(k), cells in vec(rand_witness(), len), idx in 0..len, rand_output in rand_fr()) + -> (usize, Vec>, usize, Fr) { + (k, cells, idx, rand_output) + } +} + +prop_compose! { + fn inner_product_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) + (k in Just(k), a in vec(rand_witness(), len), b in vec(rand_witness(), len), rand_output in rand_fr()) + -> (usize, Vec>, Vec>, Fr) { + (k, a, b, rand_output) + } +} + +prop_compose! { + fn inner_product_left_last_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) + (k in Just(k), a in vec(rand_witness(), len), b in vec(rand_witness(), len), rand_output in (rand_fr(), rand_fr())) + -> (usize, Vec>, Vec>, (Fr, Fr)) { + (k, a, b, rand_output) + } +} + +prop_compose! { + pub fn range_check_strat(k_bounds: (usize, usize), max_range_bits: usize) + (k in k_bounds.0..=k_bounds.1, range_bits in 1usize..=max_range_bits) // lookup_bits must be less than k + (k in Just(k), range_bits in Just(range_bits), lookup_bits in 8..k, + rand_a in prop::sample::select(vec![ + biguint_to_fe(&(BigUint::from(2u64).pow(range_bits as u32) - 1usize)), + biguint_to_fe(&BigUint::from(2u64).pow(range_bits as u32)), + biguint_to_fe(&(BigUint::from(2u64).pow(range_bits as u32) + 1usize)), + Fr::random(OsRng) + ])) + -> (usize, usize, usize, Fr) { + (k, range_bits, lookup_bits, rand_a) + } +} + +prop_compose! { + fn is_less_than_safe_strat(k_bounds: (usize, usize)) + // compose strat to generate random rand fr in range + (b in any::().prop_filter("not zero", |&i| i != 0), k in k_bounds.0..=k_bounds.1) + (k in Just(k), b in Just(b), lookup_bits in k_bounds.0 - 1..k, rand_a in rand_fr(), out in any::()) + -> (usize, u64, usize, Fr, bool) { + (k, b, lookup_bits, rand_a, out) + } +} + +fn arb_indicator(max_size: usize) -> impl Strategy> { + vec(Just(0), max_size).prop_map(|val| val.iter().map(|&x| F::from(x)).collect::>()) +} + +fn check_idx_to_indicator(idx: Fr, len: usize, ind_witnesses: &[Fr]) -> bool { + // check that: + // the length of the witnes array is correct + // the sum of the witnesses is 1, indicting that there is only one index that is 1 + if ind_witnesses.len() != len + || ind_witnesses.iter().fold(Fr::zero(), |acc, val| acc + *val) != Fr::one() + { + return false; + } + + let idx_val = idx.get_lower_128() as usize; + + // Check that all indexes are zero except for the one at idx + for (i, v) in ind_witnesses.iter().enumerate() { + if i != idx_val && *v != Fr::zero() { + return false; + } + } + true +} + +// verify rand_output == a if sel == 1, rand_output == b if sel == 0 +fn check_select(a: Fr, b: Fr, sel: Fr, rand_output: Fr) -> bool { + if (sel == Fr::zero() && rand_output != b) || (sel == Fr::one() && rand_output != a) { + return false; + } + true +} + +fn neg_test_idx_to_indicator(k: usize, len: usize, idx: usize, ind_witnesses: &[Fr]) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + // assign value to advice column before by assigning `idx` via ctx.load() -> use same method as ind_offsets to get offset + let dummy_idx = Witness(Fr::from(idx as u64)); + let indicator = gate.idx_to_indicator(builder.main(0), dummy_idx, len); + // get the offsets of the indicator cells for later 'pranking' + builder.config(k, Some(9)); + let ind_offsets = indicator.iter().map(|ind| ind.cell.unwrap().offset).collect::>(); + // prank the indicator cells + // TODO: prank the entire advice column with random values + for (offset, witness) in ind_offsets.iter().zip_eq(ind_witnesses) { + builder.main(0).advice[*offset] = Assigned::Trivial(*witness); + } + // Get idx and indicator from advice column + // Apply check instance function to `idx` and `ind_witnesses` + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + let is_valid_witness = check_idx_to_indicator(Fr::from(idx as u64), len, ind_witnesses); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_witness, + // if the proof is invalid, ignore + Err(_) => !is_valid_witness, + } +} + +fn neg_test_select( + k: usize, + a: QuantumCell, + b: QuantumCell, + sel: QuantumCell, + rand_output: Fr, +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + // add select gate + let select = gate.select(builder.main(0), a, b, sel); + + // Get the offset of `select`s output for later 'pranking' + builder.config(k, Some(9)); + let select_offset = select.cell.unwrap().offset; + // Prank the output + builder.main(0).advice[select_offset] = Assigned::Trivial(rand_output); + + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of output + let is_valid_instance = check_select(*a.value(), *b.value(), *sel.value(), rand_output); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_instance, + // if the proof is invalid, ignore + Err(_) => !is_valid_instance, + } +} + +fn neg_test_select_by_indicator( + k: usize, + a: Vec>, + idx: usize, + rand_output: Fr, +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + + let indicator = gate.idx_to_indicator(builder.main(0), Witness(Fr::from(idx as u64)), a.len()); + let a_idx = gate.select_by_indicator(builder.main(0), a.clone(), indicator); + builder.config(k, Some(9)); + + let a_idx_offset = a_idx.cell.unwrap().offset; + builder.main(0).advice[a_idx_offset] = Assigned::Trivial(rand_output); + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + // retrieve the value of a[idx] and check that it is equal to rand_output + let is_valid_witness = rand_output == *a[idx].value(); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_witness, + // if the proof is invalid, ignore + Err(_) => !is_valid_witness, + } +} + +fn neg_test_select_from_idx( + k: usize, + cells: Vec>, + idx: usize, + rand_output: Fr, +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + + let idx_val = + gate.select_from_idx(builder.main(0), cells.clone(), Witness(Fr::from(idx as u64))); + builder.config(k, Some(9)); + + let idx_offset = idx_val.cell.unwrap().offset; + builder.main(0).advice[idx_offset] = Assigned::Trivial(rand_output); + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + let is_valid_witness = rand_output == *cells[idx].value(); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_witness, + // if the proof is invalid, ignore + Err(_) => !is_valid_witness, + } +} + +fn neg_test_inner_product( + k: usize, + a: Vec>, + b: Vec>, + rand_output: Fr, +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + + let inner_product = gate.inner_product(builder.main(0), a.clone(), b.clone()); + builder.config(k, Some(9)); + + let inner_product_offset = inner_product.cell.unwrap().offset; + builder.main(0).advice[inner_product_offset] = Assigned::Trivial(rand_output); + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + let is_valid_witness = rand_output == test_ground_truths::inner_product_ground_truth(&(a, b)); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_witness, + // if the proof is invalid, ignore + Err(_) => !is_valid_witness, + } +} + +fn neg_test_inner_product_left_last( + k: usize, + a: Vec>, + b: Vec>, + rand_output: (Fr, Fr), +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + + let inner_product = gate.inner_product_left_last(builder.main(0), a.clone(), b.clone()); + builder.config(k, Some(9)); + + let inner_product_offset = + (inner_product.0.cell.unwrap().offset, inner_product.1.cell.unwrap().offset); + // prank the output cells + builder.main(0).advice[inner_product_offset.0] = Assigned::Trivial(rand_output.0); + builder.main(0).advice[inner_product_offset.1] = Assigned::Trivial(rand_output.1); + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + // (inner_product_ground_truth, a[a.len()-1]) + let inner_product_ground_truth = + test_ground_truths::inner_product_ground_truth(&(a.clone(), b)); + let is_valid_witness = + rand_output.0 == inner_product_ground_truth && rand_output.1 == *a[a.len() - 1].value(); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_witness, + // if the proof is invalid, ignore + Err(_) => !is_valid_witness, + } +} + +// Range Check + +fn neg_test_range_check(k: usize, range_bits: usize, lookup_bits: usize, rand_a: Fr) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = RangeChip::default(lookup_bits); + + let a_witness = builder.main(0).load_witness(rand_a); + gate.range_check(builder.main(0), a_witness, range_bits); + + builder.config(k, Some(9)); + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let circuit = RangeCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + let correct = fe_to_biguint(&rand_a).bits() <= range_bits as u64; + + MockProver::run(k as u32, &circuit, vec![]).unwrap().verify().is_ok() == correct +} + +// TODO: expand to prank output of is_less_than_safe() +fn neg_test_is_less_than_safe( + k: usize, + b: u64, + lookup_bits: usize, + rand_a: Fr, + prank_out: bool, +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = RangeChip::default(lookup_bits); + let ctx = builder.main(0); + + let a_witness = ctx.load_witness(rand_a); // cannot prank this later because this witness will be copy-constrained + let out = gate.is_less_than_safe(ctx, a_witness, b); + + let out_idx = out.cell.unwrap().offset; + ctx.advice[out_idx] = Assigned::Trivial(Fr::from(prank_out)); + + builder.config(k, Some(9)); + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let circuit = RangeCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + // println!("rand_a: {rand_a:?}, b: {b:?}"); + let a_big = fe_to_biguint(&rand_a); + let is_lt = a_big < BigUint::from(b); + let correct = (is_lt == prank_out) + && (a_big.bits() as usize <= (bit_length(b) + lookup_bits - 1) / lookup_bits * lookup_bits); // circuit should always fail if `a` doesn't pass range check + MockProver::run(k as u32, &circuit, vec![]).unwrap().verify().is_ok() == correct +} + +proptest! { + // Note setting the minimum value of k to 8 is intentional as it is the smallest value that will not cause an `out of columns` error. Should be noted that filtering by len * (number cells per iteration) < 2^k leads to the filtering of to many cases and the failure of the tests w/o any runs. + #[test] + fn prop_test_neg_idx_to_indicator((k, len, idx, witness_vals) in idx_to_indicator_strat((10,20),100)) { + prop_assert!(neg_test_idx_to_indicator(k, len, idx, witness_vals.as_slice())); + } + + #[test] + fn prop_test_neg_select((k, a, b, sel, rand_output) in select_strat((10,20))) { + prop_assert!(neg_test_select(k, a, b, sel, rand_output)); + } + + #[test] + fn prop_test_neg_select_by_indicator((k, a, idx, rand_output) in select_by_indicator_strat((12,20),100)) { + prop_assert!(neg_test_select_by_indicator(k, a, idx, rand_output)); + } + + #[test] + fn prop_test_neg_select_from_idx((k, cells, idx, rand_output) in select_from_idx_strat((10,20),100)) { + prop_assert!(neg_test_select_from_idx(k, cells, idx, rand_output)); + } + + #[test] + fn prop_test_neg_inner_product((k, a, b, rand_output) in inner_product_strat((10,20),100)) { + prop_assert!(neg_test_inner_product(k, a, b, rand_output)); + } + + #[test] + fn prop_test_neg_inner_product_left_last((k, a, b, rand_output) in inner_product_left_last_strat((10,20),100)) { + prop_assert!(neg_test_inner_product_left_last(k, a, b, rand_output)); + } + + #[test] + fn prop_test_neg_range_check((k, range_bits, lookup_bits, rand_a) in range_check_strat((10,23),90)) { + prop_assert!(neg_test_range_check(k, range_bits, lookup_bits, rand_a)); + } + + #[test] + fn prop_test_neg_is_less_than_safe((k, b, lookup_bits, rand_a, out) in is_less_than_safe_strat((10,20))) { + prop_assert!(neg_test_is_less_than_safe(k, b, lookup_bits, rand_a, out)); + } +} diff --git a/halo2-base/src/gates/tests/pos_prop_tests.rs b/halo2-base/src/gates/tests/pos_prop_tests.rs new file mode 100644 index 00000000..f110d12f --- /dev/null +++ b/halo2-base/src/gates/tests/pos_prop_tests.rs @@ -0,0 +1,326 @@ +use crate::gates::tests::{flex_gate_tests, range_gate_tests, test_ground_truths::*, Fr}; +use crate::utils::{bit_length, fe_to_biguint}; +use crate::{QuantumCell, QuantumCell::Witness}; +use proptest::{collection::vec, prelude::*}; +//TODO: implement Copy for rand witness and rand fr to allow for array creation +// create vec and convert to array??? +//TODO: implement arbitrary for fr using looks like you'd probably need to implement your own TestFr struct to implement Arbitrary: https://docs.rs/quickcheck/latest/quickcheck/trait.Arbitrary.html , can probably just hack it from Fr = [u64; 4] +prop_compose! { + pub fn rand_fr()(val in any::()) -> Fr { + Fr::from(val) + } +} + +prop_compose! { + pub fn rand_witness()(val in any::()) -> QuantumCell { + Witness(Fr::from(val)) + } +} + +prop_compose! { + pub fn sum_products_with_coeff_and_var_strat(max_length: usize)(val in vec((rand_fr(), rand_witness(), rand_witness()), 1..=max_length), witness in rand_witness()) -> (Vec<(Fr, QuantumCell, QuantumCell)>, QuantumCell) { + (val, witness) + } +} + +prop_compose! { + pub fn rand_bin_witness()(val in prop::sample::select(vec![Fr::zero(), Fr::one()])) -> QuantumCell { + Witness(val) + } +} + +prop_compose! { + pub fn rand_fr_range(lo: u32, hi: u32)(val in any::().prop_map(move |x| x % 2u64.pow(hi - lo))) -> Fr { + Fr::from(val) + } +} + +prop_compose! { + pub fn rand_witness_range(lo: u32, hi: u32)(val in any::().prop_map(move |x| x % 2u64.pow(hi - lo))) -> QuantumCell { + Witness(Fr::from(val)) + } +} + +// LEsson here 0..2^range_bits fails with 'Uniform::new called with `low >= high` +// therfore to still have a range of 0..2^range_bits we need on a mod it by 2^range_bits +// note k > lookup_bits +prop_compose! { + fn range_check_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize, max_range_bits: u32) + (range_bits in 2..=max_range_bits, k in k_lo..=k_hi) + (k in Just(k), lookup_bits in min_lookup_bits..(k-3), a in rand_fr_range(0, range_bits), + range_bits in Just(range_bits)) + -> (usize, usize, Fr, usize) { + (k, lookup_bits, a, range_bits as usize) + } +} + +prop_compose! { + fn check_less_than_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize, max_num_bits: usize) + (num_bits in 2..max_num_bits, k in k_lo..=k_hi) + (k in Just(k), a in rand_witness_range(0, num_bits as u32), b in rand_witness_range(0, num_bits as u32), + num_bits in Just(num_bits), lookup_bits in min_lookup_bits..k) + -> (usize, usize, QuantumCell, QuantumCell, usize) { + (k, lookup_bits, a, b, num_bits) + } +} + +prop_compose! { + fn check_less_than_safe_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize) + (k in k_lo..=k_hi) + (k in Just(k), b in any::(), a in rand_fr(), lookup_bits in min_lookup_bits..k) + -> (usize, usize, Fr, u64) { + (k, lookup_bits, a, b) + } +} + +proptest! { + + // Flex Gate Positive Tests + #[test] + fn prop_test_add(input in vec(rand_witness(), 2)) { + let ground_truth = add_ground_truth(input.as_slice()); + let result = flex_gate_tests::test_add(input.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_sub(input in vec(rand_witness(), 2)) { + let ground_truth = sub_ground_truth(input.as_slice()); + let result = flex_gate_tests::test_sub(input.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_neg(input in rand_witness()) { + let ground_truth = neg_ground_truth(input); + let result = flex_gate_tests::test_neg(input); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_mul(inputs in vec(rand_witness(), 2)) { + let ground_truth = mul_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_mul(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_mul_add(inputs in vec(rand_witness(), 3)) { + let ground_truth = mul_add_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_mul_add(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_mul_not(inputs in vec(rand_witness(), 2)) { + let ground_truth = mul_not_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_mul_not(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_assert_bit(input in rand_fr()) { + let ground_truth = input == Fr::one() || input == Fr::zero(); + let result = flex_gate_tests::test_assert_bit(input).is_ok(); + prop_assert_eq!(result, ground_truth); + } + + // Note: due to unwrap after inversion this test will fail if the denominator is zero so we want to test for that. Therefore we do not filter for zero values. + #[test] + fn prop_test_div_unsafe(inputs in vec(rand_witness().prop_filter("Input cannot be 0",|x| *x.value() != Fr::zero()), 2)) { + let ground_truth = div_unsafe_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_div_unsafe(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_assert_is_const(input in rand_fr()) { + flex_gate_tests::test_assert_is_const(&[input; 2]); + } + + #[test] + fn prop_test_inner_product(inputs in (vec(rand_witness(), 0..=100), vec(rand_witness(), 0..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { + let ground_truth = inner_product_ground_truth(&inputs); + let result = flex_gate_tests::test_inner_product(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_inner_product_left_last(inputs in (vec(rand_witness(), 1..=100), vec(rand_witness(), 1..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { + let ground_truth = inner_product_left_last_ground_truth(&inputs); + let result = flex_gate_tests::test_inner_product_left_last(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_inner_product_with_sums(inputs in (vec(rand_witness(), 0..=10), vec(rand_witness(), 1..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { + let ground_truth = inner_product_with_sums_ground_truth(&inputs); + let result = flex_gate_tests::test_inner_product_with_sums(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_sum_products_with_coeff_and_var(input in sum_products_with_coeff_and_var_strat(100)) { + let expected = sum_products_with_coeff_and_var_ground_truth(&input); + let output = flex_gate_tests::test_sum_products_with_coeff_and_var(input); + prop_assert_eq!(expected, output); + } + + #[test] + fn prop_test_and(inputs in vec(rand_witness(), 2)) { + let ground_truth = and_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_and(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_not(input in rand_witness()) { + let ground_truth = not_ground_truth(&input); + let result = flex_gate_tests::test_not(input); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_select(vals in vec(rand_witness(), 2), sel in rand_bin_witness()) { + let inputs = vec![vals[0], vals[1], sel]; + let ground_truth = select_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_select(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_or_and(inputs in vec(rand_witness(), 3)) { + let ground_truth = or_and_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_or_and(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_idx_to_indicator(input in (rand_witness(), 1..=16_usize)) { + let ground_truth = idx_to_indicator_ground_truth(input); + let result = flex_gate_tests::test_idx_to_indicator((input.0, input.1)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_select_by_indicator(inputs in (vec(rand_witness(), 1..=10), rand_witness())) { + let ground_truth = select_by_indicator_ground_truth(&inputs); + let result = flex_gate_tests::test_select_by_indicator(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_select_from_idx(inputs in (vec(rand_witness(), 1..=10), rand_witness())) { + let ground_truth = select_from_idx_ground_truth(&inputs); + let result = flex_gate_tests::test_select_from_idx(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_is_zero(x in rand_fr()) { + let ground_truth = is_zero_ground_truth(x); + let result = flex_gate_tests::test_is_zero(x); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_is_equal(inputs in vec(rand_witness(), 2)) { + let ground_truth = is_equal_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_is_equal(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_num_to_bits(num in any::()) { + let mut tmp = num; + let mut bits = vec![]; + if num == 0 { + bits.push(0); + } + while tmp > 0 { + bits.push(tmp & 1); + tmp /= 2; + } + let result = flex_gate_tests::test_num_to_bits((Fr::from(num), bits.len())); + prop_assert_eq!(bits.into_iter().map(Fr::from).collect::>(), result); + } + + /* + #[test] + fn prop_test_lagrange_eval(inputs in vec(rand_fr(), 3)) { + } + */ + + #[test] + fn prop_test_get_field_element(n in any::()) { + let ground_truth = get_field_element_ground_truth(n); + let result = flex_gate_tests::test_get_field_element::(n); + prop_assert_eq!(result, ground_truth); + } + + // Range Check Property Tests + + #[test] + fn prop_test_is_less_than(a in rand_witness(), b in any::().prop_filter("not zero", |&x| x != 0), + lookup_bits in 4..=16_usize) { + let bits = std::cmp::max(fe_to_biguint(a.value()).bits() as usize, bit_length(b)); + let ground_truth = is_less_than_ground_truth((*a.value(), Fr::from(b))); + let result = range_gate_tests::test_is_less_than(([a, Witness(Fr::from(b))], bits, lookup_bits)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_is_less_than_safe(a in rand_fr().prop_filter("not zero", |&x| x != Fr::zero()), + b in any::().prop_filter("not zero", |&x| x != 0), + lookup_bits in 4..=16_usize) { + prop_assume!(fe_to_biguint(&a).bits() as usize <= bit_length(b)); + let ground_truth = is_less_than_ground_truth((a, Fr::from(b))); + let result = range_gate_tests::test_is_less_than_safe((a, b, lookup_bits)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_div_mod(inputs in (rand_witness().prop_filter("Non-zero num", |x| *x.value() != Fr::zero()), any::().prop_filter("Non-zero divisor", |x| *x != 0u64), 1..=16_usize)) { + let ground_truth = div_mod_ground_truth((*inputs.0.value(), inputs.1)); + let result = range_gate_tests::test_div_mod((inputs.0, inputs.1, inputs.2)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_get_last_bit(input in rand_fr(), pad_bits in 0..10usize) { + let ground_truth = get_last_bit_ground_truth(input); + let bits = fe_to_biguint(&input).bits() as usize + pad_bits; + let result = range_gate_tests::test_get_last_bit((input, bits)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_div_mod_var(inputs in (rand_witness(), any::(), 1..=16_usize, 1..=16_usize)) { + let ground_truth = div_mod_ground_truth((*inputs.0.value(), inputs.1)); + let result = range_gate_tests::test_div_mod_var((inputs.0, Witness(Fr::from(inputs.1)), inputs.2, inputs.3)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_range_check((k, lookup_bits, a, range_bits) in range_check_strat((14,24), 3, 63)) { + prop_assert_eq!(range_gate_tests::test_range_check(k, lookup_bits, a, range_bits), ()); + } + + #[test] + fn prop_test_check_less_than((k, lookup_bits, a, b, num_bits) in check_less_than_strat((14,24), 3, 10)) { + prop_assume!(a.value() < b.value()); + prop_assert_eq!(range_gate_tests::test_check_less_than(k, lookup_bits, a, b, num_bits), ()); + } + + #[test] + fn prop_test_check_less_than_safe((k, lookup_bits, a, b) in check_less_than_safe_strat((12,24),3)) { + prop_assume!(a < Fr::from(b)); + prop_assert_eq!(range_gate_tests::test_check_less_than_safe(k, lookup_bits, a, b), ()); + } + + #[test] + fn prop_test_check_big_less_than_safe((k, lookup_bits, a, b) in check_less_than_safe_strat((12,24),3)) { + prop_assume!(a < Fr::from(b)); + prop_assert_eq!(range_gate_tests::test_check_big_less_than_safe(k, lookup_bits, a, b), ()); + } +} diff --git a/halo2-base/src/gates/tests/range_gate_tests.rs b/halo2-base/src/gates/tests/range_gate_tests.rs new file mode 100644 index 00000000..c781af2e --- /dev/null +++ b/halo2-base/src/gates/tests/range_gate_tests.rs @@ -0,0 +1,155 @@ +use std::env::set_var; + +use super::*; +use crate::halo2_proofs::dev::MockProver; +use crate::utils::{biguint_to_fe, ScalarField}; +use crate::QuantumCell::Witness; +use crate::{ + gates::{ + builder::{GateThreadBuilder, RangeCircuitBuilder}, + range::{RangeChip, RangeInstructions}, + }, + utils::BigPrimeField, + QuantumCell, +}; +use num_bigint::BigUint; +use test_case::test_case; + +#[test_case(16, 10, Fr::from(100), 8; "range_check() pos")] +pub fn test_range_check(k: usize, lookup_bits: usize, a_val: F, range_bits: usize) { + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = ctx.assign_witnesses([a_val])[0]; + chip.range_check(ctx, a, range_bits); + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() +} + +#[test_case(12, 10, Witness(Fr::zero()), Witness(Fr::one()), 64; "check_less_than() pos")] +pub fn test_check_less_than( + k: usize, + lookup_bits: usize, + a: QuantumCell, + b: QuantumCell, + num_bits: usize, +) { + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + chip.check_less_than(ctx, a, b, num_bits); + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() +} + +#[test_case(10, 8, Fr::zero(), 1; "check_less_than_safe() pos")] +pub fn test_check_less_than_safe(k: usize, lookup_bits: usize, a_val: F, b: u64) { + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = ctx.assign_witnesses([a_val])[0]; + chip.check_less_than_safe(ctx, a, b); + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() +} + +#[test_case(10, 8, Fr::zero(), 1; "check_big_less_than_safe() pos")] +pub fn test_check_big_less_than_safe( + k: usize, + lookup_bits: usize, + a_val: F, + b: u64, +) { + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = ctx.assign_witnesses([a_val])[0]; + chip.check_big_less_than_safe(ctx, a, BigUint::from(b)); + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() +} + +#[test_case(([0, 1].map(Fr::from).map(Witness), 3, 12) => Fr::from(1) ; "is_less_than() pos")] +pub fn test_is_less_than( + (inputs, bits, lookup_bits): ([QuantumCell; 2], usize, usize), +) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = chip.is_less_than(ctx, inputs[0], inputs[1], bits); + *a.value() +} + +#[test_case((Fr::zero(), 3, 3) => Fr::from(1) ; "is_less_than_safe() pos")] +pub fn test_is_less_than_safe((a, b, lookup_bits): (F, u64, usize)) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = ctx.load_witness(a); + let lt = chip.is_less_than_safe(ctx, a, b); + *lt.value() +} + +#[test_case((biguint_to_fe(&BigUint::from(2u64).pow(239)), BigUint::from(2u64).pow(240) - 1usize, 8) => Fr::from(1) ; "is_big_less_than_safe() pos")] +pub fn test_is_big_less_than_safe( + (a, b, lookup_bits): (F, BigUint, usize), +) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = ctx.load_witness(a); + let b = chip.is_big_less_than_safe(ctx, a, b); + *b.value() +} + +#[test_case((Witness(Fr::one()), 1, 2) => (Fr::one(), Fr::zero()) ; "div_mod() pos")] +pub fn test_div_mod( + inputs: (QuantumCell, u64, usize), +) -> (F, F) { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(3); + let a = chip.div_mod(ctx, inputs.0, BigUint::from(inputs.1), inputs.2); + (*a.0.value(), *a.1.value()) +} + +#[test_case((Fr::from(3), 8) => Fr::one() ; "get_last_bit(): 3, 8 bits")] +#[test_case((Fr::from(3), 2) => Fr::one() ; "get_last_bit(): 3, 2 bits")] +#[test_case((Fr::from(0), 2) => Fr::zero() ; "get_last_bit(): 0")] +#[test_case((Fr::from(1), 2) => Fr::one() ; "get_last_bit(): 1")] +#[test_case((Fr::from(2), 2) => Fr::zero() ; "get_last_bit(): 2")] +pub fn test_get_last_bit((a, bits): (F, usize)) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(3); + let a = ctx.load_witness(a); + let b = chip.get_last_bit(ctx, a, bits); + *b.value() +} + +#[test_case((Witness(Fr::from(3)), Witness(Fr::from(2)), 3, 3) => (Fr::one(), Fr::one()) ; "div_mod_var() pos")] +pub fn test_div_mod_var( + inputs: (QuantumCell, QuantumCell, usize, usize), +) -> (F, F) { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(3); + let a = chip.div_mod_var(ctx, inputs.0, inputs.1, inputs.2, inputs.3); + (*a.0.value(), *a.1.value()) +} diff --git a/halo2-base/src/gates/tests/test_ground_truths.rs b/halo2-base/src/gates/tests/test_ground_truths.rs new file mode 100644 index 00000000..894ff8c5 --- /dev/null +++ b/halo2-base/src/gates/tests/test_ground_truths.rs @@ -0,0 +1,190 @@ +use num_integer::Integer; + +use crate::utils::biguint_to_fe; +use crate::utils::fe_to_biguint; +use crate::utils::BigPrimeField; +use crate::utils::ScalarField; +use crate::QuantumCell; + +// Ground truth functions + +// Flex Gate Ground Truths + +pub fn add_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() + *inputs[1].value() +} + +pub fn sub_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() - *inputs[1].value() +} + +pub fn neg_ground_truth(input: QuantumCell) -> F { + -(*input.value()) +} + +pub fn mul_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() * *inputs[1].value() +} + +pub fn mul_add_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() * *inputs[1].value() + *inputs[2].value() +} + +pub fn mul_not_ground_truth(inputs: &[QuantumCell]) -> F { + (F::one() - *inputs[0].value()) * *inputs[1].value() +} + +pub fn div_unsafe_ground_truth(inputs: &[QuantumCell]) -> F { + inputs[1].value().invert().unwrap() * *inputs[0].value() +} + +pub fn inner_product_ground_truth( + inputs: &(Vec>, Vec>), +) -> F { + inputs + .0 + .iter() + .zip(inputs.1.iter()) + .fold(F::zero(), |acc, (a, b)| acc + (*a.value() * *b.value())) +} + +pub fn inner_product_left_last_ground_truth( + inputs: &(Vec>, Vec>), +) -> (F, F) { + let product = inner_product_ground_truth(inputs); + let last = *inputs.0.last().unwrap().value(); + (product, last) +} + +pub fn inner_product_with_sums_ground_truth( + input: &(Vec>, Vec>), +) -> Vec { + let (a, b) = &input; + let mut result = Vec::new(); + let mut sum = F::zero(); + // TODO: convert to fold + for (ai, bi) in a.iter().zip(b) { + let product = *ai.value() * *bi.value(); + sum += product; + result.push(sum); + } + result +} + +pub fn sum_products_with_coeff_and_var_ground_truth( + input: &(Vec<(F, QuantumCell, QuantumCell)>, QuantumCell), +) -> F { + let expected = input.0.iter().fold(F::zero(), |acc, (coeff, cell1, cell2)| { + acc + *coeff * *cell1.value() * *cell2.value() + }) + *input.1.value(); + expected +} + +pub fn and_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() * *inputs[1].value() +} + +pub fn not_ground_truth(a: &QuantumCell) -> F { + F::one() - *a.value() +} + +pub fn select_ground_truth(inputs: &[QuantumCell]) -> F { + (*inputs[0].value() - inputs[1].value()) * *inputs[2].value() + *inputs[1].value() +} + +pub fn or_and_ground_truth(inputs: &[QuantumCell]) -> F { + let bc_val = *inputs[1].value() * inputs[2].value(); + bc_val + inputs[0].value() - bc_val * inputs[0].value() +} + +pub fn idx_to_indicator_ground_truth(inputs: (QuantumCell, usize)) -> Vec { + let (idx, size) = inputs; + let mut indicator = vec![F::zero(); size]; + let mut idx_value = size + 1; + for i in 0..size as u64 { + if F::from(i) == *idx.value() { + idx_value = i as usize; + break; + } + } + if idx_value < size { + indicator[idx_value] = F::one(); + } + indicator +} + +pub fn select_by_indicator_ground_truth( + inputs: &(Vec>, QuantumCell), +) -> F { + let mut idx_value = inputs.0.len() + 1; + let mut indicator = vec![F::zero(); inputs.0.len()]; + for i in 0..inputs.0.len() as u64 { + if F::from(i) == *inputs.1.value() { + idx_value = i as usize; + break; + } + } + if idx_value < inputs.0.len() { + indicator[idx_value] = F::one(); + } + // take cross product of indicator and inputs.0 + inputs.0.iter().zip(indicator.iter()).fold(F::zero(), |acc, (a, b)| acc + (*a.value() * *b)) +} + +pub fn select_from_idx_ground_truth( + inputs: &(Vec>, QuantumCell), +) -> F { + let idx = inputs.1.value(); + // Since F does not implement From, we have to iterate and find the matching index + for i in 0..inputs.0.len() as u64 { + if F::from(i) == *idx { + return *inputs.0[i as usize].value(); + } + } + F::zero() +} + +pub fn is_zero_ground_truth(x: F) -> F { + if x.is_zero().into() { + F::one() + } else { + F::zero() + } +} + +pub fn is_equal_ground_truth(inputs: &[QuantumCell]) -> F { + if inputs[0].value() == inputs[1].value() { + F::one() + } else { + F::zero() + } +} + +/* +pub fn lagrange_eval_ground_truth(inputs: &[F]) -> (F, F) { +} +*/ + +pub fn get_field_element_ground_truth(n: u64) -> F { + F::from(n) +} + +// Range Chip Ground Truths + +pub fn is_less_than_ground_truth(inputs: (F, F)) -> F { + if inputs.0 < inputs.1 { + F::one() + } else { + F::zero() + } +} + +pub fn div_mod_ground_truth(inputs: (F, u64)) -> (F, F) { + let a = fe_to_biguint(&inputs.0); + let (div, rem) = a.div_mod_floor(&inputs.1.into()); + (biguint_to_fe(&div), biguint_to_fe(&rem)) +} + +pub fn get_last_bit_ground_truth(input: F) -> F { + F::from(input.get_lower_32() & 1 == 1) +} diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index 45224578..289d4057 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -44,6 +44,7 @@ pub mod utils; /// Constant representing whether the Layouter calls `synthesize` once just to get region shape. #[cfg(feature = "halo2-axiom")] pub const SKIP_FIRST_PASS: bool = false; +/// Constant representing whether the Layouter calls `synthesize` once just to get region shape. #[cfg(feature = "halo2-pse")] pub const SKIP_FIRST_PASS: bool = true; diff --git a/halo2-base/src/utils.rs b/halo2-base/src/utils.rs index ebed5db4..2856b267 100644 --- a/halo2-base/src/utils.rs +++ b/halo2-base/src/utils.rs @@ -13,13 +13,16 @@ use num_traits::{One, Zero}; pub trait BigPrimeField: ScalarField { /// Converts a slice of [u64] to [BigPrimeField] /// * `val`: the slice of u64 - /// Assumes val.len() <= 4 + /// + /// # Assumptions + /// * `val` has the correct length for the implementation + /// * The integer value of `val` is already less than the modulus of `Self` fn from_u64_digits(val: &[u64]) -> Self; } #[cfg(feature = "halo2-axiom")] impl BigPrimeField for F where - F: FieldExt + Hash + Into<[u64; 4]> + From<[u64; 4]>, + F: ScalarField + From<[u64; 4]>, // Assume [u64; 4] is little-endian. We only implement ScalarField when this is true. { #[inline(always)] fn from_u64_digits(val: &[u64]) -> Self { @@ -30,10 +33,9 @@ where } } -/// Helper trait to convert to and from a [ScalarField] by decomposing its an field element into [u64] limbs. -/// -/// Note: Since the number of bits necessary to represent a field element is larger than the number of bits in a u64, we decompose the bit representation of the field element into multiple [u64] values e.g. `limbs`. -#[cfg(feature = "halo2-axiom")] +/// Helper trait to represent a field element that can be converted into [u64] limbs. +/// +/// Note: Since the number of bits necessary to represent a field element is larger than the number of bits in a u64, we decompose the integer representation of the field element into multiple [u64] values e.g. `limbs`. pub trait ScalarField: FieldExt + Hash { /// Returns the base `2bit_len` little endian representation of the [ScalarField] element up to `num_limbs` number of limbs (truncates any extra limbs). /// @@ -41,27 +43,26 @@ pub trait ScalarField: FieldExt + Hash { /// * `num_limbs`: number of limbs to return /// * `bit_len`: number of bits in each limb fn to_u64_limbs(self, num_limbs: usize, bit_len: usize) -> Vec; -} -#[cfg(feature = "halo2-axiom")] -impl ScalarField for F -where - F: FieldExt + Hash + Into<[u64; 4]>, -{ - #[inline(always)] - fn to_u64_limbs(self, num_limbs: usize, bit_len: usize) -> Vec { - // Basically same as `to_repr` but does not go further into bytes - let tmp: [u64; 4] = self.into(); - decompose_u64_digits_to_limbs(tmp, num_limbs, bit_len) + + /// Returns the little endian byte representation of the element. + fn to_bytes_le(&self) -> Vec; + + /// Creates a field element from a little endian byte representation. + /// + /// The default implementation assumes that `PrimeField::from_repr` is implemented for little-endian. + /// It should be overriden if this is not the case. + fn from_bytes_le(bytes: &[u8]) -> Self { + let mut repr = Self::Repr::default(); + repr.as_mut()[..bytes.len()].copy_from_slice(bytes); + Self::from_repr(repr).unwrap() } } +// See below for implementations // Later: will need to separate BigPrimeField from ScalarField when Goldilocks is introduced #[cfg(feature = "halo2-pse")] -pub trait BigPrimeField = FieldExt + Hash; - -#[cfg(feature = "halo2-pse")] -pub trait ScalarField = FieldExt + Hash; +pub trait BigPrimeField = FieldExt + ScalarField; /// Converts an [Iterator] of u64 digits into `number_of_limbs` limbs of `bit_len` bits returned as a [Vec]. /// @@ -122,10 +123,10 @@ pub fn bit_length(x: u64) -> usize { } /// Returns the ceiling of the base 2 logarithm of `x`. -/// -/// Assumes x != 0 +/// +/// `log2_ceil(0)` returns 0. pub fn log2_ceil(x: u64) -> usize { - (u64::BITS - x.leading_zeros() - (x & (x - 1) == 0) as u32) as usize + (u64::BITS - x.leading_zeros()) as usize - usize::from(x.is_power_of_two()) } /// Returns the modulus of [BigPrimeField]. @@ -141,6 +142,9 @@ pub fn power_of_two(n: usize) -> F { /// Converts an immutable reference to [BigUint] to a [BigPrimeField]. /// * `e`: immutable reference to [BigUint] +/// +/// # Assumptions: +/// * `e` is less than the modulus of `F` pub fn biguint_to_fe(e: &BigUint) -> F { #[cfg(feature = "halo2-axiom")] { @@ -149,15 +153,16 @@ pub fn biguint_to_fe(e: &BigUint) -> F { #[cfg(feature = "halo2-pse")] { - let mut repr = F::Repr::default(); let bytes = e.to_bytes_le(); - repr.as_mut()[..bytes.len()].copy_from_slice(&bytes); - F::from_repr(repr).unwrap() + F::from_bytes_le(&bytes) } } /// Converts an immutable reference to [BigInt] to a [BigPrimeField]. /// * `e`: immutable reference to [BigInt] +/// +/// # Assumptions: +/// * The absolute value of `e` is less than the modulus of `F` pub fn bigint_to_fe(e: &BigInt) -> F { #[cfg(feature = "halo2-axiom")] { @@ -171,9 +176,7 @@ pub fn bigint_to_fe(e: &BigInt) -> F { #[cfg(feature = "halo2-pse")] { let (sign, bytes) = e.to_bytes_le(); - let mut repr = F::Repr::default(); - repr.as_mut()[..bytes.len()].copy_from_slice(&bytes); - let f_abs = F::from_repr(repr).unwrap(); + let f_abs = F::from_bytes_le(&bytes); if sign == Sign::Minus { -f_abs } else { @@ -182,14 +185,17 @@ pub fn bigint_to_fe(e: &BigInt) -> F { } } -/// Converts an immutable reference to an PrimeField element into a [BigUint] element. +/// Converts an immutable reference to an PrimeField element into a [BigUint] element. /// * `fe`: immutable reference to PrimeField element to convert -pub fn fe_to_biguint(fe: &F) -> BigUint { - BigUint::from_bytes_le(fe.to_repr().as_ref()) +pub fn fe_to_biguint(fe: &F) -> BigUint { + BigUint::from_bytes_le(fe.to_bytes_le().as_ref()) } -/// Converts an immutable reference to a [BigPrimeField] element into a [BigInt] element. -/// * `fe`: immutable reference to [BigPrimeField] element to convert +/// Converts a [BigPrimeField] element into a [BigInt] element by sending `fe` in `[0, F::modulus())` to +/// ```ignore +/// fe, if fe < F::modulus() / 2 +/// fe - F::modulus(), otherwise +/// ``` pub fn fe_to_bigint(fe: &F) -> BigInt { // TODO: `F` should just have modulus as lazy_static or something let modulus = modulus::(); @@ -202,7 +208,7 @@ pub fn fe_to_bigint(fe: &F) -> BigInt { } /// Decomposes an immutable reference to a [BigPrimeField] element into `number_of_limbs` limbs of `bit_len` bits each and returns a [Vec] of [BigPrimeField] represented by those limbs. -/// +/// /// Assumes `bit_len < 128`. /// * `e`: immutable reference to [BigPrimeField] element to decompose /// * `number_of_limbs`: number of limbs to decompose `e` into @@ -243,6 +249,8 @@ pub fn decompose_fe_to_u64_limbs( /// * `e`: immutable reference to [BigInt] to decompose /// * `num_limbs`: number of limbs to decompose `e` into /// * `bit_len`: number of bits in each limb +/// +/// Truncates to `num_limbs` limbs if `e` is too large. pub fn decompose_biguint( e: &BigUint, num_limbs: usize, @@ -282,7 +290,7 @@ pub fn decompose_biguint( } /// Decomposes an immutable reference to a [BigInt] into `num_limbs` limbs of `bit_len` bits each and returns a [Vec] of [BigPrimeField] represented by those limbs. -/// +/// /// Assumes `bit_len < 128`. /// * `e`: immutable reference to `BigInt` to decompose /// * `num_limbs`: number of limbs to decompose `e` into @@ -296,7 +304,7 @@ pub fn decompose_bigint(e: &BigInt, num_limbs: usize, bit_len: } /// Decomposes an immutable reference to a [BigInt] into `num_limbs` limbs of `bit_len` bits each and returns a [Vec] of [BigPrimeField] represented by those limbs wrapped in [Value]. -/// +/// /// Assumes `bit_len` < 128. /// * `e`: immutable reference to `BigInt` to decompose /// * `num_limbs`: number of limbs to decompose `e` into @@ -309,7 +317,7 @@ pub fn decompose_bigint_option( value.map(|e| decompose_bigint(e, number_of_limbs, bit_len)).transpose_vec(number_of_limbs) } -/// Wraps the internal value of `value` in an [Option]. +/// Wraps the internal value of `value` in an [Option]. /// If the value is [None], then the function returns [None]. /// * `value`: Value to convert. pub fn value_to_option(value: Value) -> Option { @@ -332,6 +340,7 @@ pub fn compose(input: Vec, bit_len: usize) -> BigUint { #[cfg(feature = "halo2-axiom")] pub use halo2_proofs_axiom::halo2curves::CurveAffineExt; +/// Helper trait #[cfg(feature = "halo2-pse")] pub trait CurveAffineExt: CurveAffine { /// Unlike the `Coordinates` trait, this just returns the raw affine (X, Y) coordinantes without checking `is_on_curve` @@ -343,6 +352,67 @@ pub trait CurveAffineExt: CurveAffine { #[cfg(feature = "halo2-pse")] impl CurveAffineExt for C {} +mod scalar_field_impls { + use super::{decompose_u64_digits_to_limbs, ScalarField}; + use crate::halo2_proofs::halo2curves::{ + bn256::{Fq as bn254Fq, Fr as bn254Fr}, + secp256k1::{Fp as secpFp, Fq as secpFq}, + }; + #[cfg(feature = "halo2-pse")] + use ff::PrimeField; + + /// To ensure `ScalarField` is only implemented for `ff:Field` where `Repr` is little endian, we use the following macro + /// to implement the trait for each field. + #[cfg(feature = "halo2-axiom")] + #[macro_export] + macro_rules! impl_scalar_field { + ($field:ident) => { + impl ScalarField for $field { + #[inline(always)] + fn to_u64_limbs(self, num_limbs: usize, bit_len: usize) -> Vec { + // Basically same as `to_repr` but does not go further into bytes + let tmp: [u64; 4] = self.into(); + decompose_u64_digits_to_limbs(tmp, num_limbs, bit_len) + } + + #[inline(always)] + fn to_bytes_le(&self) -> Vec { + let tmp: [u64; 4] = (*self).into(); + tmp.iter().flat_map(|x| x.to_le_bytes()).collect() + } + } + }; + } + + /// To ensure `ScalarField` is only implemented for `ff:Field` where `Repr` is little endian, we use the following macro + /// to implement the trait for each field. + #[cfg(feature = "halo2-pse")] + #[macro_export] + macro_rules! impl_scalar_field { + ($field:ident) => { + impl ScalarField for $field { + #[inline(always)] + fn to_u64_limbs(self, num_limbs: usize, bit_len: usize) -> Vec { + let bytes = self.to_repr(); + let digits = (0..4) + .map(|i| u64::from_le_bytes(bytes[i * 8..(i + 1) * 8].try_into().unwrap())); + decompose_u64_digits_to_limbs(digits, num_limbs, bit_len) + } + + #[inline(always)] + fn to_bytes_le(&self) -> Vec { + self.to_repr().to_vec() + } + } + }; + } + + impl_scalar_field!(bn254Fr); + impl_scalar_field!(bn254Fq); + impl_scalar_field!(secpFp); + impl_scalar_field!(secpFq); +} + /// Module for reading parameters for Halo2 proving system from the file system. pub mod fs { use std::{ @@ -401,7 +471,7 @@ pub mod fs { } } - /// Generates the SRS for the KZG scheme and writes it to a file found in "./params/kzg_bn2_{k}.srs` or `{dir}/kzg_bn254_{k}.srs` if `PARAMS_DIR` env var is specified, creates a file it if it does not exist" + /// Generates the SRS for the KZG scheme and writes it to a file found in "./params/kzg_bn2_{k}.srs` or `{dir}/kzg_bn254_{k}.srs` if `PARAMS_DIR` env var is specified, creates a file it if it does not exist" /// * `k`: degree that expresses the size of circuit (i.e., 2^k is the number of rows in the circuit) pub fn gen_srs(k: u32) -> ParamsKZG { read_or_create_srs::(k, |k| { @@ -481,4 +551,9 @@ mod tests { } } } + + #[test] + fn test_log2_ceil_zero() { + assert_eq!(log2_ceil(0), 0); + } } diff --git a/halo2-ecc/benches/fp_mul.rs b/halo2-ecc/benches/fp_mul.rs index c2de04ce..48351c45 100644 --- a/halo2-ecc/benches/fp_mul.rs +++ b/halo2-ecc/benches/fp_mul.rs @@ -44,7 +44,7 @@ fn fp_mul_bench( let range = RangeChip::::default(lookup_bits); let chip = FpChip::::new(&range, limb_bits, num_limbs); - let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, FpChip::::fe_to_witness(&x))); + let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); for _ in 0..2857 { chip.mul(ctx, &a, &b); } diff --git a/halo2-ecc/benches/msm.rs b/halo2-ecc/benches/msm.rs index a1b0a301..3a98ee38 100644 --- a/halo2-ecc/benches/msm.rs +++ b/halo2-ecc/benches/msm.rs @@ -59,8 +59,10 @@ fn msm_bench( let ctx = builder.main(0); let scalars_assigned = scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); - let bases_assigned = - bases.iter().map(|base| ecc_chip.load_private(ctx, (base.x, base.y))).collect::>(); + let bases_assigned = bases + .iter() + .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) + .collect::>(); ecc_chip.variable_base_msm_in::( builder, diff --git a/halo2-ecc/src/bigint/add_no_carry.rs b/halo2-ecc/src/bigint/add_no_carry.rs index e7d920a8..19feb35d 100644 --- a/halo2-ecc/src/bigint/add_no_carry.rs +++ b/halo2-ecc/src/bigint/add_no_carry.rs @@ -1,35 +1,37 @@ use super::{CRTInteger, OverflowInteger}; use halo2_base::{gates::GateInstructions, utils::ScalarField, Context}; +use itertools::Itertools; use std::cmp::max; +/// # Assumptions +/// * `a, b` have same number of limbs pub fn assign( gate: &impl GateInstructions, ctx: &mut Context, - a: &OverflowInteger, - b: &OverflowInteger, + a: OverflowInteger, + b: OverflowInteger, ) -> OverflowInteger { - debug_assert_eq!(a.limbs.len(), b.limbs.len()); - let out_limbs = a .limbs - .iter() - .zip(b.limbs.iter()) - .map(|(&a_limb, &b_limb)| gate.add(ctx, a_limb, b_limb)) + .into_iter() + .zip_eq(b.limbs) + .map(|(a_limb, b_limb)| gate.add(ctx, a_limb, b_limb)) .collect(); - OverflowInteger::construct(out_limbs, max(a.max_limb_bits, b.max_limb_bits) + 1) + OverflowInteger::new(out_limbs, max(a.max_limb_bits, b.max_limb_bits) + 1) } +/// # Assumptions +/// * `a, b` have same number of limbs // pass by reference to avoid cloning the BigInt in CRTInteger, unclear if this is optimal pub fn crt( gate: &impl GateInstructions, ctx: &mut Context, - a: &CRTInteger, - b: &CRTInteger, + a: CRTInteger, + b: CRTInteger, ) -> CRTInteger { - debug_assert_eq!(a.truncation.limbs.len(), b.truncation.limbs.len()); - let out_trunc = assign::(gate, ctx, &a.truncation, &b.truncation); + let out_trunc = assign(gate, ctx, a.truncation, b.truncation); let out_native = gate.add(ctx, a.native, b.native); - let out_val = &a.value + &b.value; - CRTInteger::construct(out_trunc, out_native, out_val) + let out_val = a.value + b.value; + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/big_is_equal.rs b/halo2-ecc/src/bigint/big_is_equal.rs index f64a3fae..78626b22 100644 --- a/halo2-ecc/src/bigint/big_is_equal.rs +++ b/halo2-ecc/src/bigint/big_is_equal.rs @@ -1,45 +1,29 @@ -use super::{CRTInteger, OverflowInteger}; +use super::ProperUint; use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; +use itertools::Itertools; -/// Given OverflowInteger's `a` and `b` of the same shape, +/// Given [`ProperUint`]s `a` and `b` with the same number of limbs, /// returns whether `a == b`. +/// +/// # Assumptions: +/// * `a, b` have the same number of limbs. +/// * The number of limbs is nonzero. pub fn assign( gate: &impl GateInstructions, ctx: &mut Context, - a: &OverflowInteger, - b: &OverflowInteger, + a: impl Into>, + b: impl Into>, ) -> AssignedValue { - let k = a.limbs.len(); - debug_assert_eq!(k, b.limbs.len()); - debug_assert_ne!(k, 0); + let a = a.into(); + let b = b.into(); + debug_assert!(!a.0.is_empty()); - let mut a_limbs = a.limbs.iter(); - let mut b_limbs = b.limbs.iter(); - let mut partial = gate.is_equal(ctx, *a_limbs.next().unwrap(), *b_limbs.next().unwrap()); - for (&a_limb, &b_limb) in a_limbs.zip(b_limbs) { + let mut a_limbs = a.0.into_iter(); + let mut b_limbs = b.0.into_iter(); + let mut partial = gate.is_equal(ctx, a_limbs.next().unwrap(), b_limbs.next().unwrap()); + for (a_limb, b_limb) in a_limbs.zip_eq(b_limbs) { let eq_limb = gate.is_equal(ctx, a_limb, b_limb); partial = gate.and(ctx, eq_limb, partial); } partial } - -pub fn wrapper( - gate: &impl GateInstructions, - ctx: &mut Context, - a: &CRTInteger, - b: &CRTInteger, -) -> AssignedValue { - assign(gate, ctx, &a.truncation, &b.truncation) -} - -pub fn crt( - gate: &impl GateInstructions, - ctx: &mut Context, - a: &CRTInteger, - b: &CRTInteger, -) -> AssignedValue { - debug_assert_eq!(a.value, b.value); - let out_trunc = assign::(gate, ctx, &a.truncation, &b.truncation); - let out_native = gate.is_equal(ctx, a.native, b.native); - gate.and(ctx, out_trunc, out_native) -} diff --git a/halo2-ecc/src/bigint/big_is_zero.rs b/halo2-ecc/src/bigint/big_is_zero.rs index d6b03cd5..aa67c842 100644 --- a/halo2-ecc/src/bigint/big_is_zero.rs +++ b/halo2-ecc/src/bigint/big_is_zero.rs @@ -1,44 +1,53 @@ -use super::{CRTInteger, OverflowInteger}; +use super::{OverflowInteger, ProperCrtUint, ProperUint}; use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; -/// assume you know that the limbs of `a` are all in [0, 2^{a.max_limb_bits}) +/// # Assumptions +/// * `a` has nonzero number of limbs +/// * The limbs of `a` are all in [0, 2a.max_limb_bits) +/// * a.limbs.len() * 2a.max_limb_bits ` is less than modulus of `F` pub fn positive( gate: &impl GateInstructions, ctx: &mut Context, - a: &OverflowInteger, + a: OverflowInteger, ) -> AssignedValue { let k = a.limbs.len(); - debug_assert_ne!(k, 0); - debug_assert!(a.max_limb_bits as u32 + k.ilog2() < F::CAPACITY); + assert_ne!(k, 0); + assert!(a.max_limb_bits as u32 + k.ilog2() < F::CAPACITY); - let sum = gate.sum(ctx, a.limbs.iter().copied()); + let sum = gate.sum(ctx, a.limbs); gate.is_zero(ctx, sum) } -// given OverflowInteger `a`, returns whether `a == 0` +/// Given ProperUint `a`, returns 1 iff every limb of `a` is zero. Returns 0 otherwise. +/// +/// It is almost always more efficient to use [`positive`] instead. +/// +/// # Assumptions +/// * `a` has nonzero number of limbs pub fn assign( gate: &impl GateInstructions, ctx: &mut Context, - a: &OverflowInteger, + a: ProperUint, ) -> AssignedValue { - let k = a.limbs.len(); - debug_assert_ne!(k, 0); + assert!(!a.0.is_empty()); - let mut a_limbs = a.limbs.iter(); - let mut partial = gate.is_zero(ctx, *a_limbs.next().unwrap()); - for &a_limb in a_limbs { + let mut a_limbs = a.0.into_iter(); + let mut partial = gate.is_zero(ctx, a_limbs.next().unwrap()); + for a_limb in a_limbs { let limb_is_zero = gate.is_zero(ctx, a_limb); partial = gate.and(ctx, limb_is_zero, partial); } partial } +/// Returns 0 or 1. Returns 1 iff the limbs of `a` are identically zero. +/// This just calls [`assign`] on the limbs. +/// +/// It is almost always more efficient to use [`positive`] instead. pub fn crt( gate: &impl GateInstructions, ctx: &mut Context, - a: &CRTInteger, + a: ProperCrtUint, ) -> AssignedValue { - let out_trunc = assign::(gate, ctx, &a.truncation); - let out_native = gate.is_zero(ctx, a.native); - gate.and(ctx, out_trunc, out_native) + assign(gate, ctx, ProperUint(a.0.truncation.limbs)) } diff --git a/halo2-ecc/src/bigint/big_less_than.rs b/halo2-ecc/src/bigint/big_less_than.rs index 276de18c..01fe1eae 100644 --- a/halo2-ecc/src/bigint/big_less_than.rs +++ b/halo2-ecc/src/bigint/big_less_than.rs @@ -1,4 +1,4 @@ -use super::OverflowInteger; +use super::ProperUint; use halo2_base::{gates::RangeInstructions, utils::ScalarField, AssignedValue, Context}; // given OverflowInteger's `a` and `b` of the same shape, @@ -6,12 +6,12 @@ use halo2_base::{gates::RangeInstructions, utils::ScalarField, AssignedValue, Co pub fn assign( range: &impl RangeInstructions, ctx: &mut Context, - a: &OverflowInteger, - b: &OverflowInteger, + a: impl Into>, + b: impl Into>, limb_bits: usize, limb_base: F, ) -> AssignedValue { // a < b iff a - b has underflow - let (_, underflow) = super::sub::assign::(range, ctx, a, b, limb_bits, limb_base); + let (_, underflow) = super::sub::assign(range, ctx, a, b, limb_bits, limb_base); underflow } diff --git a/halo2-ecc/src/bigint/carry_mod.rs b/halo2-ecc/src/bigint/carry_mod.rs index 4b266cf3..a78fd32b 100644 --- a/halo2-ecc/src/bigint/carry_mod.rs +++ b/halo2-ecc/src/bigint/carry_mod.rs @@ -1,4 +1,5 @@ -use super::{check_carry_to_zero, CRTInteger, OverflowInteger}; +use std::{cmp::max, iter}; + use halo2_base::{ gates::{range::RangeStrategy, GateInstructions, RangeInstructions}, utils::{decompose_bigint, BigPrimeField}, @@ -8,7 +9,8 @@ use halo2_base::{ use num_bigint::BigInt; use num_integer::Integer; use num_traits::{One, Signed}; -use std::{cmp::max, iter}; + +use super::{check_carry_to_zero, CRTInteger, OverflowInteger, ProperCrtUint, ProperUint}; // Input `a` is `CRTInteger` with `a.truncation` of length `k` with "signed" limbs // Output is `out = a (mod modulus)` as CRTInteger with @@ -18,7 +20,10 @@ use std::{cmp::max, iter}; // `out.native = (a (mod modulus)) % (native_modulus::)` // We constrain `a = out + modulus * quotient` and range check `out` and `quotient` // -// Assumption: the leading two bits (in big endian) are 1, and `abs(a) <= 2^{n * k - 1 + F::NUM_BITS - 2}` (A weaker assumption is also enough, but this is good enough for forseeable use cases) +// Assumption: the leading two bits (in big endian) are 1, +/// # Assumptions +/// * abs(a) <= 2n * k - 1 + F::NUM_BITS - 2 (A weaker assumption is also enough, but this is good enough for forseeable use cases) +/// * `native_modulus::` requires *exactly* `k = a.limbs.len()` limbs to represent // This is currently optimized for limbs greater than 64 bits, so we need `F` to be a `BigPrimeField` // In the future we'll need a slightly different implementation for limbs that fit in 32 or 64 bits (e.g., `F` is Goldilocks) @@ -26,7 +31,7 @@ pub fn crt( range: &impl RangeInstructions, // chip: &BigIntConfig, ctx: &mut Context, - a: &CRTInteger, + a: CRTInteger, k_bits: usize, // = a.len().bits() modulus: &BigInt, mod_vec: &[F], @@ -34,7 +39,7 @@ pub fn crt( limb_bits: usize, limb_bases: &[F], limb_base_big: &BigInt, -) -> CRTInteger { +) -> ProperCrtUint { let n = limb_bits; let k = a.truncation.limbs.len(); let trunc_len = n * k; @@ -93,8 +98,8 @@ pub fn crt( // strategies where we carry out school-book multiplication in some form: // BigIntStrategy::Simple => { - for (i, (a_limb, (quot_v, out_v))) in - a.truncation.limbs.iter().zip(quot_vec.into_iter().zip(out_vec.into_iter())).enumerate() + for (i, ((a_limb, quot_v), out_v)) in + a.truncation.limbs.into_iter().zip(quot_vec).zip(out_vec).enumerate() { let (prod, new_quot_cell) = range.gate().inner_product_left_last( ctx, @@ -117,7 +122,7 @@ pub fn crt( ctx.assign_region( [ Constant(-F::one()), - Existing(*a_limb), + Existing(a_limb), Witness(temp1), Constant(F::one()), Witness(out_v), @@ -153,7 +158,7 @@ pub fn crt( range.range_check(ctx, quot_shift, limb_bits + 1); } - let check_overflow_int = OverflowInteger::construct( + let check_overflow_int = OverflowInteger::new( check_assigned, max(max(limb_bits, a.truncation.max_limb_bits) + 1, 2 * n + k_bits), ); @@ -169,21 +174,12 @@ pub fn crt( ); // Constrain `quot_native = sum_i quot_assigned[i] * 2^{n*i}` in `F` - let quot_native = OverflowInteger::::evaluate( - range.gate(), - ctx, - quot_assigned, - limb_bases.iter().copied(), - ); + let quot_native = + OverflowInteger::evaluate_native(ctx, range.gate(), quot_assigned, limb_bases); // Constrain `out_native = sum_i out_assigned[i] * 2^{n*i}` in `F` - let out_native = OverflowInteger::::evaluate( - range.gate(), - ctx, - out_assigned.iter().copied(), - limb_bases.iter().copied(), - ); - + let out_native = + OverflowInteger::evaluate_native(ctx, range.gate(), out_assigned.clone(), limb_bases); // We save 1 cell by connecting `out_native` computation with the following: // Check `out + modulus * quotient - a = 0` in native field @@ -193,5 +189,9 @@ pub fn crt( [-1], // negative index because -1 relative offset is `out_native` assigned value ); - CRTInteger::construct(OverflowInteger::construct(out_assigned, limb_bits), out_native, out_val) + ProperCrtUint(CRTInteger::new( + ProperUint(out_assigned).into_overflow(limb_bits), + out_native, + out_val, + )) } diff --git a/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs b/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs index db6f9084..6232cbdf 100644 --- a/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs +++ b/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs @@ -16,7 +16,7 @@ use std::{cmp::max, iter}; pub fn crt( range: &impl RangeInstructions, ctx: &mut Context, - a: &CRTInteger, + a: CRTInteger, k_bits: usize, // = a.len().bits() modulus: &BigInt, mod_vec: &[F], @@ -43,7 +43,7 @@ pub fn crt( // we need to constrain that `sum_i quot_vec[i] * 2^{n*i} = quot_native` in `F` let (quot_val, _out_val) = a.value.div_mod_floor(modulus); - // only perform safety checks in display mode so we can turn them off in production + // only perform safety checks in debug mode debug_assert_eq!(_out_val, BigInt::zero()); debug_assert!(quot_val.abs() < (BigInt::one() << quot_max_bits)); @@ -68,7 +68,7 @@ pub fn crt( // match chip.strategy { // BigIntStrategy::Simple => { - for (i, (a_limb, quot_v)) in a.truncation.limbs.iter().zip(quot_vec.into_iter()).enumerate() { + for (i, (a_limb, quot_v)) in a.truncation.limbs.into_iter().zip(quot_vec).enumerate() { let (prod, new_quot_cell) = range.gate().inner_product_left_last( ctx, quot_assigned.iter().map(|x| Existing(*x)).chain(iter::once(Witness(quot_v))), @@ -80,7 +80,7 @@ pub fn crt( // | prod | -1 | a | prod - a | let check_val = *prod.value() - a_limb.value(); let check_cell = ctx - .assign_region_last([Constant(-F::one()), Existing(*a_limb), Witness(check_val)], [-1]); + .assign_region_last([Constant(-F::one()), Existing(a_limb), Witness(check_val)], [-1]); quot_assigned.push(new_quot_cell); check_assigned.push(check_cell); @@ -100,7 +100,7 @@ pub fn crt( } let check_overflow_int = - OverflowInteger::construct(check_assigned, max(a.truncation.max_limb_bits, 2 * n + k_bits)); + OverflowInteger::new(check_assigned, max(a.truncation.max_limb_bits, 2 * n + k_bits)); // check that `modulus * quotient - a == 0 mod 2^{trunc_len}` after carry check_carry_to_zero::truncate::( @@ -113,12 +113,8 @@ pub fn crt( ); // Constrain `quot_native = sum_i out_assigned[i] * 2^{n*i}` in `F` - let quot_native = OverflowInteger::::evaluate( - range.gate(), - ctx, - quot_assigned, - limb_bases.iter().copied(), - ); + let quot_native = + OverflowInteger::evaluate_native(ctx, range.gate(), quot_assigned, limb_bases); // Check `0 + modulus * quotient - a = 0` in native field // | 0 | modulus | quotient | a | diff --git a/halo2-ecc/src/bigint/mod.rs b/halo2-ecc/src/bigint/mod.rs index f7f2886c..ea14b127 100644 --- a/halo2-ecc/src/bigint/mod.rs +++ b/halo2-ecc/src/bigint/mod.rs @@ -1,4 +1,3 @@ -use crate::halo2_proofs::circuit::Cell; use halo2_base::{ gates::flex_gate::GateInstructions, utils::{biguint_to_fe, decompose_biguint, fe_to_biguint, BigPrimeField, ScalarField}, @@ -24,8 +23,7 @@ pub mod select_by_indicator; pub mod sub; pub mod sub_no_carry; -#[derive(Clone, Debug, PartialEq)] -#[derive(Default)] +#[derive(Clone, Debug, PartialEq, Default)] pub enum BigIntStrategy { // use existing gates #[default] @@ -35,8 +33,6 @@ pub enum BigIntStrategy { // CustomVerticalShort, } - - #[derive(Clone, Debug)] pub struct OverflowInteger { pub limbs: Vec>, @@ -47,7 +43,7 @@ pub struct OverflowInteger { } impl OverflowInteger { - pub fn construct(limbs: Vec>, max_limb_bits: usize) -> Self { + pub fn new(limbs: Vec>, max_limb_bits: usize) -> Self { Self { limbs, max_limb_bits } } @@ -65,17 +61,57 @@ impl OverflowInteger { .fold(BigInt::zero(), |acc, acell| (acc << limb_bits) + fe_to_bigint(acell.value())) } - pub fn evaluate( - gate: &impl GateInstructions, + /// Computes `sum_i limbs[i] * limb_bases[i]` in native field `F`. + /// In practice assumes `limb_bases[i] = 2^{limb_bits * i}`. + pub fn evaluate_native( ctx: &mut Context, + gate: &impl GateInstructions, limbs: impl IntoIterator>, - limb_bases: impl IntoIterator, + limb_bases: &[F], ) -> AssignedValue { // Constrain `out_native = sum_i out_assigned[i] * 2^{n*i}` in `F` - gate.inner_product(ctx, limbs, limb_bases.into_iter().map(|c| Constant(c))) + gate.inner_product(ctx, limbs, limb_bases.iter().map(|c| Constant(*c))) } } +/// Safe wrapper around a BigUint represented as a vector of limbs in **little endian**. +/// The underlying BigUint is represented by +/// sumi limbs\[i\] * 2limb_bits * i +/// +/// To save memory we do not store the `limb_bits` and it must be inferred from context. +#[repr(transparent)] +#[derive(Clone, Debug)] +pub struct ProperUint(pub(crate) Vec>); + +impl ProperUint { + pub fn limbs(&self) -> &[AssignedValue] { + self.0.as_slice() + } + + pub fn into_overflow(self, limb_bits: usize) -> OverflowInteger { + OverflowInteger::new(self.0, limb_bits) + } + + /// Computes `sum_i limbs[i] * limb_bases[i]` in native field `F`. + /// In practice assumes `limb_bases[i] = 2^{limb_bits * i}`. + /// + /// Assumes that `value` is the underlying BigUint value represented by `self`. + pub fn into_crt( + self, + ctx: &mut Context, + gate: &impl GateInstructions, + value: BigUint, + limb_bases: &[F], + limb_bits: usize, + ) -> ProperCrtUint { + // Constrain `out_native = sum_i out_assigned[i] * 2^{n*i}` in `F` + let native = + OverflowInteger::evaluate_native(ctx, gate, self.0.iter().copied(), limb_bases); + ProperCrtUint(CRTInteger::new(self.into_overflow(limb_bits), native, value.into())) + } +} + +#[repr(transparent)] #[derive(Clone, Debug)] pub struct FixedOverflowInteger { pub limbs: Vec, @@ -101,9 +137,9 @@ impl FixedOverflowInteger { .fold(BigUint::zero(), |acc, x| (acc << limb_bits) + fe_to_biguint(x)) } - pub fn assign(self, ctx: &mut Context, limb_bits: usize) -> OverflowInteger { + pub fn assign(self, ctx: &mut Context) -> ProperUint { let assigned_limbs = self.limbs.into_iter().map(|limb| ctx.load_constant(limb)).collect(); - OverflowInteger::construct(assigned_limbs, limb_bits) + ProperUint(assigned_limbs) } /// only use case is when coeffs has only a single 1, rest are 0 @@ -123,7 +159,7 @@ impl FixedOverflowInteger { }) .collect(); - OverflowInteger::construct(out_limbs, limb_bits) + OverflowInteger::new(out_limbs, limb_bits) } } @@ -131,7 +167,7 @@ impl FixedOverflowInteger { pub struct CRTInteger { // keep track of an integer `a` using CRT as `a mod 2^t` and `a mod n` // where `t = truncation.limbs.len() * truncation.limb_bits` - // `n = modulus::` + // `n = modulus::` // `value` is the actual integer value we want to keep track of // we allow `value` to be a signed BigInt @@ -145,12 +181,21 @@ pub struct CRTInteger { pub value: BigInt, } +impl AsRef> for CRTInteger { + fn as_ref(&self) -> &CRTInteger { + self + } +} + +// Cloning all the time impacts readability so we'll just implement From<&T> for T +impl<'a, F: ScalarField> From<&'a CRTInteger> for CRTInteger { + fn from(x: &'a CRTInteger) -> Self { + x.clone() + } +} + impl CRTInteger { - pub fn construct( - truncation: OverflowInteger, - native: AssignedValue, - value: BigInt, - ) -> Self { + pub fn new(truncation: OverflowInteger, native: AssignedValue, value: BigInt) -> Self { Self { truncation, native, value } } @@ -163,6 +208,62 @@ impl CRTInteger { } } +/// Safe wrapper for representing a BigUint as a [`CRTInteger`] whose underlying BigUint value is in `[0, 2^t)` +/// where `t = truncation.limbs.len() * limb_bits`. This struct guarantees that +/// * each `truncation.limbs[i]` is ranged checked to be in `[0, 2^limb_bits)`, +/// * `native` is the evaluation of `sum_i truncation.limbs[i] * 2^{limb_bits * i} (mod modulus::)` in the native field `F` +/// * `value` is equal to `sum_i truncation.limbs[i] * 2^{limb_bits * i}` as integers +/// +/// Note this means `native` and `value` are completely determined by `truncation`. However, we still store them explicitly for convenience. +#[repr(transparent)] +#[derive(Clone, Debug)] +pub struct ProperCrtUint(pub(crate) CRTInteger); + +impl AsRef> for ProperCrtUint { + fn as_ref(&self) -> &CRTInteger { + &self.0 + } +} + +impl<'a, F: ScalarField> From<&'a ProperCrtUint> for ProperCrtUint { + fn from(x: &'a ProperCrtUint) -> Self { + x.clone() + } +} + +// cannot blanket implement From> for T because of Rust +impl From> for CRTInteger { + fn from(x: ProperCrtUint) -> Self { + x.0 + } +} + +impl<'a, F: ScalarField> From<&'a ProperCrtUint> for CRTInteger { + fn from(x: &'a ProperCrtUint) -> Self { + x.0.clone() + } +} + +impl From> for ProperUint { + fn from(x: ProperCrtUint) -> Self { + ProperUint(x.0.truncation.limbs) + } +} + +impl ProperCrtUint { + pub fn limbs(&self) -> &[AssignedValue] { + self.0.limbs() + } + + pub fn native(&self) -> &AssignedValue { + self.0.native() + } + + pub fn value(&self) -> BigUint { + self.0.value.to_biguint().expect("Value of proper uint should not be negative") + } +} + #[derive(Clone, Debug)] pub struct FixedCRTInteger { // keep track of an integer `a` using CRT as `a mod 2^t` and `a mod n` @@ -180,15 +281,8 @@ pub struct FixedCRTInteger { pub value: BigUint, } -#[derive(Clone, Debug)] -pub struct FixedAssignedCRTInteger { - pub truncation: FixedOverflowInteger, - pub limb_fixed_cells: Vec, - pub value: BigUint, -} - impl FixedCRTInteger { - pub fn construct(truncation: FixedOverflowInteger, value: BigUint) -> Self { + pub fn new(truncation: FixedOverflowInteger, value: BigUint) -> Self { Self { truncation, value } } @@ -204,9 +298,9 @@ impl FixedCRTInteger { ctx: &mut Context, limb_bits: usize, native_modulus: &BigUint, - ) -> CRTInteger { - let assigned_truncation = self.truncation.assign(ctx, limb_bits); + ) -> ProperCrtUint { + let assigned_truncation = self.truncation.assign(ctx).into_overflow(limb_bits); let assigned_native = ctx.load_constant(biguint_to_fe(&(&self.value % native_modulus))); - CRTInteger::construct(assigned_truncation, assigned_native, self.value.into()) + ProperCrtUint(CRTInteger::new(assigned_truncation, assigned_native, self.value.into())) } } diff --git a/halo2-ecc/src/bigint/mul_no_carry.rs b/halo2-ecc/src/bigint/mul_no_carry.rs index b6d5e745..aa174c3d 100644 --- a/halo2-ecc/src/bigint/mul_no_carry.rs +++ b/halo2-ecc/src/bigint/mul_no_carry.rs @@ -1,11 +1,16 @@ use super::{CRTInteger, OverflowInteger}; use halo2_base::{gates::GateInstructions, utils::ScalarField, Context, QuantumCell::Existing}; +/// # Assumptions +/// * `a` and `b` have the same number of limbs `k` +/// * `k` is nonzero +/// * `num_limbs_log2_ceil = log2_ceil(k)` +/// * `log2_ceil(k) + a.max_limb_bits + b.max_limb_bits <= F::NUM_BITS as usize - 2` pub fn truncate( gate: &impl GateInstructions, ctx: &mut Context, - a: &OverflowInteger, - b: &OverflowInteger, + a: OverflowInteger, + b: OverflowInteger, num_limbs_log2_ceil: usize, ) -> OverflowInteger { let k = a.limbs.len(); @@ -26,19 +31,19 @@ pub fn truncate( }) .collect(); - OverflowInteger::construct(out_limbs, num_limbs_log2_ceil + a.max_limb_bits + b.max_limb_bits) + OverflowInteger::new(out_limbs, num_limbs_log2_ceil + a.max_limb_bits + b.max_limb_bits) } pub fn crt( gate: &impl GateInstructions, ctx: &mut Context, - a: &CRTInteger, - b: &CRTInteger, + a: CRTInteger, + b: CRTInteger, num_limbs_log2_ceil: usize, ) -> CRTInteger { - let out_trunc = truncate::(gate, ctx, &a.truncation, &b.truncation, num_limbs_log2_ceil); + let out_trunc = truncate::(gate, ctx, a.truncation, b.truncation, num_limbs_log2_ceil); let out_native = gate.mul(ctx, a.native, b.native); - let out_val = &a.value * &b.value; + let out_val = a.value * b.value; - CRTInteger::construct(out_trunc, out_native, out_val) + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/negative.rs b/halo2-ecc/src/bigint/negative.rs index 45a7d817..74e61da1 100644 --- a/halo2-ecc/src/bigint/negative.rs +++ b/halo2-ecc/src/bigint/negative.rs @@ -7,5 +7,5 @@ pub fn assign( a: OverflowInteger, ) -> OverflowInteger { let out_limbs = a.limbs.into_iter().map(|limb| gate.neg(ctx, limb)).collect(); - OverflowInteger::construct(out_limbs, a.max_limb_bits) + OverflowInteger::new(out_limbs, a.max_limb_bits) } diff --git a/halo2-ecc/src/bigint/scalar_mul_and_add_no_carry.rs b/halo2-ecc/src/bigint/scalar_mul_and_add_no_carry.rs index 579aff01..5c818453 100644 --- a/halo2-ecc/src/bigint/scalar_mul_and_add_no_carry.rs +++ b/halo2-ecc/src/bigint/scalar_mul_and_add_no_carry.rs @@ -5,36 +5,40 @@ use halo2_base::{ Context, QuantumCell::Constant, }; +use itertools::Itertools; use std::cmp::max; /// compute a * c + b = b + a * c +/// +/// # Assumptions +/// * `a, b` have same number of limbs +/// * Number of limbs is nonzero +/// * `c_log2_ceil = log2_ceil(c)` where `c` is the BigUint value of `c_f` // this is uniquely suited for our simple gate pub fn assign( gate: &impl GateInstructions, ctx: &mut Context, - a: &OverflowInteger, - b: &OverflowInteger, + a: OverflowInteger, + b: OverflowInteger, c_f: F, c_log2_ceil: usize, ) -> OverflowInteger { - debug_assert_eq!(a.limbs.len(), b.limbs.len()); - let out_limbs = a .limbs - .iter() - .zip(b.limbs.iter()) - .map(|(&a_limb, &b_limb)| gate.mul_add(ctx, a_limb, Constant(c_f), b_limb)) + .into_iter() + .zip_eq(b.limbs) + .map(|(a_limb, b_limb)| gate.mul_add(ctx, a_limb, Constant(c_f), b_limb)) .collect(); - OverflowInteger::construct(out_limbs, max(a.max_limb_bits + c_log2_ceil, b.max_limb_bits) + 1) + OverflowInteger::new(out_limbs, max(a.max_limb_bits + c_log2_ceil, b.max_limb_bits) + 1) } /// compute a * c + b = b + a * c pub fn crt( gate: &impl GateInstructions, ctx: &mut Context, - a: &CRTInteger, - b: &CRTInteger, + a: CRTInteger, + b: CRTInteger, c: i64, ) -> CRTInteger { debug_assert_eq!(a.truncation.limbs.len(), b.truncation.limbs.len()); @@ -47,8 +51,8 @@ pub fn crt( (-F::from(c_abs), c_abs) }; - let out_trunc = assign::(gate, ctx, &a.truncation, &b.truncation, c_f, log2_ceil(c_abs)); + let out_trunc = assign(gate, ctx, a.truncation, b.truncation, c_f, log2_ceil(c_abs)); let out_native = gate.mul_add(ctx, a.native, Constant(c_f), b.native); - let out_val = &a.value * c + &b.value; - CRTInteger::construct(out_trunc, out_native, out_val) + let out_val = a.value * c + b.value; + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/scalar_mul_no_carry.rs b/halo2-ecc/src/bigint/scalar_mul_no_carry.rs index 60029e92..fdbc4058 100644 --- a/halo2-ecc/src/bigint/scalar_mul_no_carry.rs +++ b/halo2-ecc/src/bigint/scalar_mul_no_carry.rs @@ -14,13 +14,13 @@ pub fn assign( c_log2_ceil: usize, ) -> OverflowInteger { let out_limbs = a.limbs.into_iter().map(|limb| gate.mul(ctx, limb, Constant(c_f))).collect(); - OverflowInteger::construct(out_limbs, a.max_limb_bits + c_log2_ceil) + OverflowInteger::new(out_limbs, a.max_limb_bits + c_log2_ceil) } pub fn crt( gate: &impl GateInstructions, ctx: &mut Context, - a: &CRTInteger, + a: CRTInteger, c: i64, ) -> CRTInteger { let (c_f, c_abs) = if c >= 0 { @@ -31,15 +31,9 @@ pub fn crt( (-F::from(c_abs), c_abs) }; - let out_limbs = - a.truncation.limbs.iter().map(|limb| gate.mul(ctx, *limb, Constant(c_f))).collect(); - + let out_overflow = assign(gate, ctx, a.truncation, c_f, log2_ceil(c_abs)); let out_native = gate.mul(ctx, a.native, Constant(c_f)); - let out_val = &a.value * c; + let out_val = a.value * c; - CRTInteger::construct( - OverflowInteger::construct(out_limbs, a.truncation.max_limb_bits + log2_ceil(c_abs)), - out_native, - out_val, - ) + CRTInteger::new(out_overflow, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/select.rs b/halo2-ecc/src/bigint/select.rs index 1146eeb5..65fd7333 100644 --- a/halo2-ecc/src/bigint/select.rs +++ b/halo2-ecc/src/bigint/select.rs @@ -1,7 +1,11 @@ use super::{CRTInteger, OverflowInteger}; use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; +use itertools::Itertools; use std::cmp::max; +/// # Assumptions +/// * `a, b` have same number of limbs +/// * Number of limbs is nonzero pub fn assign( gate: &impl GateInstructions, ctx: &mut Context, @@ -9,39 +13,38 @@ pub fn assign( b: OverflowInteger, sel: AssignedValue, ) -> OverflowInteger { - debug_assert_eq!(a.limbs.len(), b.limbs.len()); let out_limbs = a .limbs .into_iter() - .zip(b.limbs.into_iter()) + .zip_eq(b.limbs) .map(|(a_limb, b_limb)| gate.select(ctx, a_limb, b_limb, sel)) .collect(); - OverflowInteger::construct(out_limbs, max(a.max_limb_bits, b.max_limb_bits)) + OverflowInteger::new(out_limbs, max(a.max_limb_bits, b.max_limb_bits)) } pub fn crt( gate: &impl GateInstructions, ctx: &mut Context, - a: &CRTInteger, - b: &CRTInteger, + a: CRTInteger, + b: CRTInteger, sel: AssignedValue, ) -> CRTInteger { debug_assert_eq!(a.truncation.limbs.len(), b.truncation.limbs.len()); let out_limbs = a .truncation .limbs - .iter() - .zip(b.truncation.limbs.iter()) - .map(|(&a_limb, &b_limb)| gate.select(ctx, a_limb, b_limb, sel)) + .into_iter() + .zip_eq(b.truncation.limbs) + .map(|(a_limb, b_limb)| gate.select(ctx, a_limb, b_limb, sel)) .collect(); - let out_trunc = OverflowInteger::construct( + let out_trunc = OverflowInteger::new( out_limbs, max(a.truncation.max_limb_bits, b.truncation.max_limb_bits), ); let out_native = gate.select(ctx, a.native, b.native, sel); - let out_val = if sel.value().is_zero_vartime() { b.value.clone() } else { a.value.clone() }; - CRTInteger::construct(out_trunc, out_native, out_val) + let out_val = if sel.value().is_zero_vartime() { b.value } else { a.value }; + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/select_by_indicator.rs b/halo2-ecc/src/bigint/select_by_indicator.rs index 30aa5ab2..d1658d04 100644 --- a/halo2-ecc/src/bigint/select_by_indicator.rs +++ b/halo2-ecc/src/bigint/select_by_indicator.rs @@ -22,48 +22,48 @@ pub fn assign( let max_limb_bits = a.iter().fold(0, |acc, x| max(acc, x.max_limb_bits)); - OverflowInteger::construct(out_limbs, max_limb_bits) + OverflowInteger::new(out_limbs, max_limb_bits) } /// only use case is when coeffs has only a single 1, rest are 0 pub fn crt( gate: &impl GateInstructions, ctx: &mut Context, - a: &[CRTInteger], + a: &[impl AsRef>], coeffs: &[AssignedValue], limb_bases: &[F], ) -> CRTInteger { assert_eq!(a.len(), coeffs.len()); - let k = a[0].truncation.limbs.len(); + let k = a[0].as_ref().truncation.limbs.len(); let out_limbs = (0..k) .map(|idx| { - let int_limbs = a.iter().map(|a| a.truncation.limbs[idx]); + let int_limbs = a.iter().map(|a| a.as_ref().truncation.limbs[idx]); gate.select_by_indicator(ctx, int_limbs, coeffs.iter().copied()) }) .collect(); - let max_limb_bits = a.iter().fold(0, |acc, x| max(acc, x.truncation.max_limb_bits)); + let max_limb_bits = a.iter().fold(0, |acc, x| max(acc, x.as_ref().truncation.max_limb_bits)); - let out_trunc = OverflowInteger::construct(out_limbs, max_limb_bits); + let out_trunc = OverflowInteger::new(out_limbs, max_limb_bits); let out_native = if a.len() > k { - OverflowInteger::::evaluate( - gate, + OverflowInteger::evaluate_native( ctx, + gate, out_trunc.limbs.iter().copied(), - limb_bases[..k].iter().copied(), + &limb_bases[..k], ) } else { - let a_native = a.iter().map(|x| x.native); + let a_native = a.iter().map(|x| x.as_ref().native); gate.select_by_indicator(ctx, a_native, coeffs.iter().copied()) }; let out_val = a.iter().zip(coeffs.iter()).fold(BigInt::zero(), |acc, (x, y)| { if y.value().is_zero_vartime() { acc } else { - x.value.clone() + x.as_ref().value.clone() } }); - CRTInteger::construct(out_trunc, out_native, out_val) + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/sub.rs b/halo2-ecc/src/bigint/sub.rs index 2d4d83ff..8b2263f9 100644 --- a/halo2-ecc/src/bigint/sub.rs +++ b/halo2-ecc/src/bigint/sub.rs @@ -1,28 +1,30 @@ -use super::{CRTInteger, OverflowInteger}; +use super::{CRTInteger, OverflowInteger, ProperCrtUint, ProperUint}; use halo2_base::{ gates::{GateInstructions, RangeInstructions}, utils::ScalarField, AssignedValue, Context, QuantumCell::{Constant, Existing, Witness}, }; +use itertools::Itertools; -/// Should only be called on integers a, b in proper representation with all limbs having at most `limb_bits` number of bits +/// # Assumptions +/// * Should only be called on integers a, b in proper representation with all limbs having at most `limb_bits` number of bits +/// * `a, b` have same nonzero number of limbs pub fn assign( range: &impl RangeInstructions, ctx: &mut Context, - a: &OverflowInteger, - b: &OverflowInteger, + a: impl Into>, + b: impl Into>, limb_bits: usize, limb_base: F, ) -> (OverflowInteger, AssignedValue) { - debug_assert!(a.max_limb_bits <= limb_bits); - debug_assert!(b.max_limb_bits <= limb_bits); - debug_assert_eq!(a.limbs.len(), b.limbs.len()); - let k = a.limbs.len(); + let a = a.into(); + let b = b.into(); + let k = a.0.len(); let mut out_limbs = Vec::with_capacity(k); let mut borrow: Option> = None; - for (&a_limb, &b_limb) in a.limbs.iter().zip(b.limbs.iter()) { + for (a_limb, b_limb) in a.0.into_iter().zip_eq(b.0) { let (bottom, lt) = match borrow { None => { let lt = range.is_less_than(ctx, a_limb, b_limb, limb_bits); @@ -54,21 +56,24 @@ pub fn assign( out_limbs.push(out_limb); borrow = Some(lt); } - (OverflowInteger::construct(out_limbs, limb_bits), borrow.unwrap()) + (OverflowInteger::new(out_limbs, limb_bits), borrow.unwrap()) } // returns (a-b, underflow), where underflow is nonzero iff a < b +/// # Assumptions +/// * `a, b` are proper CRT representations of integers with the same number of limbs pub fn crt( range: &impl RangeInstructions, ctx: &mut Context, - a: &CRTInteger, - b: &CRTInteger, + a: ProperCrtUint, + b: ProperCrtUint, limb_bits: usize, limb_base: F, ) -> (CRTInteger, AssignedValue) { - let (out_trunc, underflow) = - assign::(range, ctx, &a.truncation, &b.truncation, limb_bits, limb_base); - let out_native = range.gate().sub(ctx, a.native, b.native); - let out_val = &a.value - &b.value; - (CRTInteger::construct(out_trunc, out_native, out_val), underflow) + let out_native = range.gate().sub(ctx, a.0.native, b.0.native); + let a_limbs = ProperUint(a.0.truncation.limbs); + let b_limbs = ProperUint(b.0.truncation.limbs); + let (out_trunc, underflow) = assign(range, ctx, a_limbs, b_limbs, limb_bits, limb_base); + let out_val = a.0.value - b.0.value; + (CRTInteger::new(out_trunc, out_native, out_val), underflow) } diff --git a/halo2-ecc/src/bigint/sub_no_carry.rs b/halo2-ecc/src/bigint/sub_no_carry.rs index ae4bb8a3..4e8867c0 100644 --- a/halo2-ecc/src/bigint/sub_no_carry.rs +++ b/halo2-ecc/src/bigint/sub_no_carry.rs @@ -1,32 +1,34 @@ use super::{CRTInteger, OverflowInteger}; use halo2_base::{gates::GateInstructions, utils::ScalarField, Context}; +use itertools::Itertools; use std::cmp::max; +/// # Assumptions +/// * `a, b` have same number of limbs pub fn assign( gate: &impl GateInstructions, ctx: &mut Context, - a: &OverflowInteger, - b: &OverflowInteger, + a: OverflowInteger, + b: OverflowInteger, ) -> OverflowInteger { - debug_assert_eq!(a.limbs.len(), b.limbs.len()); let out_limbs = a .limbs - .iter() - .zip(b.limbs.iter()) - .map(|(&a_limb, &b_limb)| gate.sub(ctx, a_limb, b_limb)) + .into_iter() + .zip_eq(b.limbs) + .map(|(a_limb, b_limb)| gate.sub(ctx, a_limb, b_limb)) .collect(); - OverflowInteger::construct(out_limbs, max(a.max_limb_bits, b.max_limb_bits) + 1) + OverflowInteger::new(out_limbs, max(a.max_limb_bits, b.max_limb_bits) + 1) } pub fn crt( gate: &impl GateInstructions, ctx: &mut Context, - a: &CRTInteger, - b: &CRTInteger, + a: CRTInteger, + b: CRTInteger, ) -> CRTInteger { - let out_trunc = assign::(gate, ctx, &a.truncation, &b.truncation); + let out_trunc = assign(gate, ctx, a.truncation, b.truncation); let out_native = gate.sub(ctx, a.native, b.native); - let out_val = &a.value - &b.value; - CRTInteger::construct(out_trunc, out_native, out_val) + let out_val = a.value - b.value; + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bn254/final_exp.rs b/halo2-ecc/src/bn254/final_exp.rs index 9ab45daa..7959142e 100644 --- a/halo2-ecc/src/bn254/final_exp.rs +++ b/halo2-ecc/src/bn254/final_exp.rs @@ -1,18 +1,13 @@ -use super::{Fp12Chip, Fp2Chip, FpChip, FpPoint}; +use super::{Fp12Chip, Fp2Chip, FpChip, FqPoint}; use crate::halo2_proofs::{ arithmetic::Field, halo2curves::bn256::{Fq, Fq2, BN_X, FROBENIUS_COEFF_FQ12_C1}, }; use crate::{ ecc::get_naf, - fields::{fp12::mul_no_carry_w6, FieldChip, FieldExtPoint, PrimeField}, -}; -use halo2_base::{ - gates::GateInstructions, - utils::{fe_to_biguint, modulus}, - Context, - QuantumCell::Constant, + fields::{fp12::mul_no_carry_w6, vector::FieldVector, FieldChip, PrimeField}, }; +use halo2_base::{gates::GateInstructions, utils::modulus, Context, QuantumCell::Constant}; use num_bigint::BigUint; const XI_0: i64 = 9; @@ -28,46 +23,48 @@ impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { ) -> >::FieldPoint { assert_eq!(modulus::() % 4u64, BigUint::from(3u64)); assert_eq!(modulus::() % 6u64, BigUint::from(1u64)); - assert_eq!(a.coeffs.len(), 12); + assert_eq!(a.0.len(), 12); let pow = power % 12; let mut out_fp2 = Vec::with_capacity(6); - let fp2_chip = Fp2Chip::::new(self.fp_chip); + let fp_chip = self.fp_chip(); + let fp2_chip = Fp2Chip::::new(fp_chip); for i in 0..6 { let frob_coeff = FROBENIUS_COEFF_FQ12_C1[pow].pow_vartime([i as u64]); // possible optimization (not implemented): load `frob_coeff` as we multiply instead of loading first // frobenius map is used infrequently so this is a small optimization - let mut a_fp2 = - FieldExtPoint::construct(vec![a.coeffs[i].clone(), a.coeffs[i + 6].clone()]); + let mut a_fp2 = FieldVector(vec![a[i].clone(), a[i + 6].clone()]); if pow % 2 != 0 { - a_fp2 = fp2_chip.conjugate(ctx, &a_fp2); + a_fp2 = fp2_chip.conjugate(ctx, a_fp2); } // if `frob_coeff` is in `Fp` and not just `Fp2`, then we can be more efficient in multiplication if frob_coeff == Fq2::one() { out_fp2.push(a_fp2); } else if frob_coeff.c1 == Fq::zero() { - let frob_fixed = fp2_chip.fp_chip.load_constant(ctx, fe_to_biguint(&frob_coeff.c0)); + let frob_fixed = fp_chip.load_constant(ctx, frob_coeff.c0); { - let out_nocarry = fp2_chip.fp_mul_no_carry(ctx, &a_fp2, &frob_fixed); - out_fp2.push(fp2_chip.carry_mod(ctx, &out_nocarry)); + let out_nocarry = fp2_chip.0.fp_mul_no_carry(ctx, a_fp2, frob_fixed); + out_fp2.push(fp2_chip.carry_mod(ctx, out_nocarry)); } } else { let frob_fixed = fp2_chip.load_constant(ctx, frob_coeff); - out_fp2.push(fp2_chip.mul(ctx, &a_fp2, &frob_fixed)); + out_fp2.push(fp2_chip.mul(ctx, a_fp2, frob_fixed)); } } let out_coeffs = out_fp2 .iter() - .map(|x| x.coeffs[0].clone()) - .chain(out_fp2.iter().map(|x| x.coeffs[1].clone())) + .map(|x| x[0].clone()) + .chain(out_fp2.iter().map(|x| x[1].clone())) .collect(); - FieldExtPoint::construct(out_coeffs) + FieldVector(out_coeffs) } // exp is in little-endian + /// # Assumptions + /// * `a` is nonzero field point pub fn pow( &self, ctx: &mut Context, @@ -86,7 +83,11 @@ impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { if z != 0 { assert!(z == 1 || z == -1); if is_started { - res = if z == 1 { self.mul(ctx, &res, a) } else { self.divide(ctx, &res, a) }; + res = if z == 1 { + self.mul(ctx, &res, a) + } else { + self.divide_unsafe(ctx, &res, a) + }; } else { assert_eq!(z, 1); is_started = true; @@ -106,14 +107,12 @@ impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { /// in = g0 + g2 w + g4 w^2 + g1 w^3 + g3 w^4 + g5 w^5 where g_i = g_i0 + g_i1 * u are elements of Fp2 /// out = Compress(in) = [ g2, g3, g4, g5 ] - pub fn cyclotomic_compress( - &self, - a: &FieldExtPoint>, - ) -> Vec>> { - let g2 = FieldExtPoint::construct(vec![a.coeffs[1].clone(), a.coeffs[1 + 6].clone()]); - let g3 = FieldExtPoint::construct(vec![a.coeffs[4].clone(), a.coeffs[4 + 6].clone()]); - let g4 = FieldExtPoint::construct(vec![a.coeffs[2].clone(), a.coeffs[2 + 6].clone()]); - let g5 = FieldExtPoint::construct(vec![a.coeffs[5].clone(), a.coeffs[5 + 6].clone()]); + pub fn cyclotomic_compress(&self, a: &FqPoint) -> Vec> { + let a = &a.0; + let g2 = FieldVector(vec![a[1].clone(), a[1 + 6].clone()]); + let g3 = FieldVector(vec![a[4].clone(), a[4 + 6].clone()]); + let g4 = FieldVector(vec![a[2].clone(), a[2 + 6].clone()]); + let g5 = FieldVector(vec![a[5].clone(), a[5 + 6].clone()]); vec![g2, g3, g4, g5] } @@ -132,13 +131,14 @@ impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { pub fn cyclotomic_decompress( &self, ctx: &mut Context, - compression: Vec>>, - ) -> FieldExtPoint> { - let [g2, g3, g4, g5]: [FieldExtPoint>; 4] = compression.try_into().unwrap(); + compression: Vec>, + ) -> FqPoint { + let [g2, g3, g4, g5]: [_; 4] = compression.try_into().unwrap(); - let fp2_chip = Fp2Chip::::new(self.fp_chip); + let fp_chip = self.fp_chip(); + let fp2_chip = Fp2Chip::::new(fp_chip); let g5_sq = fp2_chip.mul_no_carry(ctx, &g5, &g5); - let g5_sq_c = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, &g5_sq); + let g5_sq_c = mul_no_carry_w6::<_, _, XI_0>(fp_chip, ctx, g5_sq); let g4_sq = fp2_chip.mul_no_carry(ctx, &g4, &g4); let g4_sq_3 = fp2_chip.scalar_mul_no_carry(ctx, &g4_sq, 3); @@ -148,15 +148,15 @@ impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { g1_num = fp2_chip.sub_no_carry(ctx, &g1_num, &g3_2); // can divide without carrying g1_num or g1_denom (I think) let g2_4 = fp2_chip.scalar_mul_no_carry(ctx, &g2, 4); - let g1_1 = fp2_chip.divide(ctx, &g1_num, &g2_4); + let g1_1 = fp2_chip.divide_unsafe(ctx, &g1_num, &g2_4); let g4_g5 = fp2_chip.mul_no_carry(ctx, &g4, &g5); let g1_num = fp2_chip.scalar_mul_no_carry(ctx, &g4_g5, 2); - let g1_0 = fp2_chip.divide(ctx, &g1_num, &g3); + let g1_0 = fp2_chip.divide_unsafe(ctx, &g1_num, &g3); let g2_is_zero = fp2_chip.is_zero(ctx, &g2); // resulting `g1` is already in "carried" format (witness is in `[0, p)`) - let g1 = fp2_chip.select(ctx, &g1_0, &g1_1, g2_is_zero); + let g1 = fp2_chip.0.select(ctx, g1_0, g1_1, g2_is_zero); // share the computation of 2 g1^2 between the two cases let g1_sq = fp2_chip.mul_no_carry(ctx, &g1, &g1); @@ -166,26 +166,26 @@ impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { let g3_g4 = fp2_chip.mul_no_carry(ctx, &g3, &g4); let g3_g4_3 = fp2_chip.scalar_mul_no_carry(ctx, &g3_g4, 3); let temp = fp2_chip.add_no_carry(ctx, &g1_sq_2, &g2_g5); - let temp = fp2_chip.select(ctx, &g1_sq_2, &temp, g2_is_zero); + let temp = fp2_chip.0.select(ctx, g1_sq_2, temp, g2_is_zero); let temp = fp2_chip.sub_no_carry(ctx, &temp, &g3_g4_3); - let mut g0 = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, &temp); + let mut g0 = mul_no_carry_w6::<_, _, XI_0>(fp_chip, ctx, temp); // compute `g0 + 1` - g0.coeffs[0].truncation.limbs[0] = - fp2_chip.gate().add(ctx, g0.coeffs[0].truncation.limbs[0], Constant(F::one())); - g0.coeffs[0].native = fp2_chip.gate().add(ctx, g0.coeffs[0].native, Constant(F::one())); - g0.coeffs[0].truncation.max_limb_bits += 1; - g0.coeffs[0].value += 1usize; + g0[0].truncation.limbs[0] = + fp2_chip.gate().add(ctx, g0[0].truncation.limbs[0], Constant(F::one())); + g0[0].native = fp2_chip.gate().add(ctx, g0[0].native, Constant(F::one())); + g0[0].truncation.max_limb_bits += 1; + g0[0].value += 1usize; // finally, carry g0 - g0 = fp2_chip.carry_mod(ctx, &g0); + let g0 = fp2_chip.carry_mod(ctx, g0); - let mut g0 = g0.coeffs.into_iter(); - let mut g1 = g1.coeffs.into_iter(); - let mut g2 = g2.coeffs.into_iter(); - let mut g3 = g3.coeffs.into_iter(); - let mut g4 = g4.coeffs.into_iter(); - let mut g5 = g5.coeffs.into_iter(); + let mut g0 = g0.into_iter(); + let mut g1 = g1.into_iter(); + let mut g2 = g2.into_iter(); + let mut g3 = g3.into_iter(); + let mut g4 = g4.into_iter(); + let mut g5 = g5.into_iter(); let mut out_coeffs = Vec::with_capacity(12); for _ in 0..2 { @@ -198,7 +198,7 @@ impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { g5.next().unwrap(), ]); } - FieldExtPoint::construct(out_coeffs) + FieldVector(out_coeffs) } // input is [g2, g3, g4, g5] = C(g) in compressed format of `cyclotomic_compress` @@ -216,58 +216,56 @@ impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { pub fn cyclotomic_square( &self, ctx: &mut Context, - compression: &[FieldExtPoint>], - ) -> Vec>> { + compression: &[FqPoint], + ) -> Vec> { assert_eq!(compression.len(), 4); let g2 = &compression[0]; let g3 = &compression[1]; let g4 = &compression[2]; let g5 = &compression[3]; - let fp2_chip = Fp2Chip::::new(self.fp_chip); + let fp_chip = self.fp_chip(); + let fp2_chip = Fp2Chip::::new(fp_chip); let g2_plus_g3 = fp2_chip.add_no_carry(ctx, g2, g3); - let cg3 = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, g3); + let cg3 = mul_no_carry_w6::, XI_0>(fp_chip, ctx, g3.into()); let g2_plus_cg3 = fp2_chip.add_no_carry(ctx, g2, &cg3); let a23 = fp2_chip.mul_no_carry(ctx, &g2_plus_g3, &g2_plus_cg3); let g4_plus_g5 = fp2_chip.add_no_carry(ctx, g4, g5); - let cg5 = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, g5); + let cg5 = mul_no_carry_w6::<_, _, XI_0>(fp_chip, ctx, g5.into()); let g4_plus_cg5 = fp2_chip.add_no_carry(ctx, g4, &cg5); let a45 = fp2_chip.mul_no_carry(ctx, &g4_plus_g5, &g4_plus_cg5); let b23 = fp2_chip.mul_no_carry(ctx, g2, g3); let b45 = fp2_chip.mul_no_carry(ctx, g4, g5); - let b45_c = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, &b45); + let b45_c = mul_no_carry_w6::<_, _, XI_0>(fp_chip, ctx, b45.clone()); let mut temp = fp2_chip.scalar_mul_and_add_no_carry(ctx, &b45_c, g2, 3); let h2 = fp2_chip.scalar_mul_no_carry(ctx, &temp, 2); - temp = fp2_chip.add_no_carry(ctx, &b45_c, &b45); - temp = fp2_chip.sub_no_carry(ctx, &a45, &temp); - temp = fp2_chip.scalar_mul_no_carry(ctx, &temp, 3); - let h3 = fp2_chip.scalar_mul_and_add_no_carry(ctx, g3, &temp, -2); + temp = fp2_chip.add_no_carry(ctx, b45_c, b45); + temp = fp2_chip.sub_no_carry(ctx, &a45, temp); + temp = fp2_chip.scalar_mul_no_carry(ctx, temp, 3); + let h3 = fp2_chip.scalar_mul_and_add_no_carry(ctx, g3, temp, -2); const XI0_PLUS_1: i64 = XI_0 + 1; // (c + 1) = (XI_0 + 1) + u - temp = mul_no_carry_w6::, XI0_PLUS_1>(fp2_chip.fp_chip, ctx, &b23); - temp = fp2_chip.sub_no_carry(ctx, &a23, &temp); - temp = fp2_chip.scalar_mul_no_carry(ctx, &temp, 3); - let h4 = fp2_chip.scalar_mul_and_add_no_carry(ctx, g4, &temp, -2); + temp = mul_no_carry_w6::, XI0_PLUS_1>(fp_chip, ctx, b23.clone()); + temp = fp2_chip.sub_no_carry(ctx, &a23, temp); + temp = fp2_chip.scalar_mul_no_carry(ctx, temp, 3); + let h4 = fp2_chip.scalar_mul_and_add_no_carry(ctx, g4, temp, -2); - temp = fp2_chip.scalar_mul_and_add_no_carry(ctx, &b23, g5, 3); - let h5 = fp2_chip.scalar_mul_no_carry(ctx, &temp, 2); + temp = fp2_chip.scalar_mul_and_add_no_carry(ctx, b23, g5, 3); + let h5 = fp2_chip.scalar_mul_no_carry(ctx, temp, 2); - [h2, h3, h4, h5].iter().map(|h| fp2_chip.carry_mod(ctx, h)).collect() + [h2, h3, h4, h5].into_iter().map(|h| fp2_chip.carry_mod(ctx, h)).collect() } // exp is in little-endian - pub fn cyclotomic_pow( - &self, - ctx: &mut Context, - a: FieldExtPoint>, - exp: Vec, - ) -> FieldExtPoint> { + /// # Assumptions + /// * `a` is a nonzero element in the cyclotomic subgroup + pub fn cyclotomic_pow(&self, ctx: &mut Context, a: FqPoint, exp: Vec) -> FqPoint { let mut compression = self.cyclotomic_compress(&a); let mut out = None; let mut is_started = false; @@ -281,7 +279,11 @@ impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { assert!(z == 1 || z == -1); if is_started { let mut res = self.cyclotomic_decompress(ctx, compression); - res = if z == 1 { self.mul(ctx, &res, &a) } else { self.divide(ctx, &res, &a) }; + res = if z == 1 { + self.mul(ctx, &res, &a) + } else { + self.divide_unsafe(ctx, &res, &a) + }; // compression is free, so it doesn't hurt (except possibly witness generation runtime) to do it // TODO: alternatively we go from small bits to large to avoid this compression compression = self.cyclotomic_compress(&res); @@ -318,7 +320,7 @@ impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { let mp2_mp3 = self.mul(ctx, &mp2, &mp3); let y0 = self.mul(ctx, &mp, &mp2_mp3); // y1 = 1/m, inverse = frob(6) = conjugation in cyclotomic subgroup - let y1 = self.conjugate(ctx, &m); + let y1 = self.conjugate(ctx, m.clone()); // m^x let mx = self.cyclotomic_pow(ctx, m, vec![BN_X]); @@ -333,20 +335,20 @@ impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { let y2 = self.frobenius_map(ctx, &mx2, 2); // m^{x^3} // y5 = 1/mx2 - let y5 = self.conjugate(ctx, &mx2); + let y5 = self.conjugate(ctx, mx2.clone()); let mx3 = self.cyclotomic_pow(ctx, mx2, vec![BN_X]); // (m^{x^3})^p let mx3p = self.frobenius_map(ctx, &mx3, 1); // y3 = 1/mxp - let y3 = self.conjugate(ctx, &mxp); + let y3 = self.conjugate(ctx, mxp); // y4 = 1/(mx * mx2p) let mx_mx2p = self.mul(ctx, &mx, &mx2p); - let y4 = self.conjugate(ctx, &mx_mx2p); + let y4 = self.conjugate(ctx, mx_mx2p); // y6 = 1/(mx3 * mx3p) let mx3_mx3p = self.mul(ctx, &mx3, &mx3p); - let y6 = self.conjugate(ctx, &mx3_mx3p); + let y6 = self.conjugate(ctx, mx3_mx3p); // out = y0 * y1^2 * y2^6 * y3^12 * y4^18 * y5^30 * y6^36 // we compute this using the vectorial addition chain from p. 6 of https://eprint.iacr.org/2008/490.pdf @@ -368,14 +370,16 @@ impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { } // out = in^{ (q^6 - 1)*(q^2 + 1) } + /// # Assumptions + /// * `a` is nonzero field point pub fn easy_part( &self, ctx: &mut Context, - a: &>::FieldPoint, + a: >::FieldPoint, ) -> >::FieldPoint { // a^{q^6} = conjugate of a - let f1 = self.conjugate(ctx, a); - let f2 = self.divide(ctx, &f1, a); + let f1 = self.conjugate(ctx, a.clone()); + let f2 = self.divide_unsafe(ctx, &f1, a); let f3 = self.frobenius_map(ctx, &f2, 2); self.mul(ctx, &f3, &f2) } @@ -384,7 +388,7 @@ impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { pub fn final_exp( &self, ctx: &mut Context, - a: &>::FieldPoint, + a: >::FieldPoint, ) -> >::FieldPoint { let f0 = self.easy_part(ctx, a); let f = self.hard_part_BN(ctx, f0); diff --git a/halo2-ecc/src/bn254/mod.rs b/halo2-ecc/src/bn254/mod.rs index 6640f729..deed3c4d 100644 --- a/halo2-ecc/src/bn254/mod.rs +++ b/halo2-ecc/src/bn254/mod.rs @@ -1,15 +1,14 @@ +use crate::bigint::ProperCrtUint; +use crate::fields::vector::FieldVector; +use crate::fields::{fp, fp12, fp2}; use crate::halo2_proofs::halo2curves::bn256::{Fq, Fq12, Fq2}; -use crate::{ - bigint::CRTInteger, - fields::{fp, fp12, fp2, FieldExtPoint}, -}; pub mod final_exp; pub mod pairing; pub type FpChip<'range, F> = fp::FpChip<'range, F, Fq>; -pub type FpPoint = CRTInteger; -pub type FqPoint = FieldExtPoint>; +pub type FpPoint = ProperCrtUint; +pub type FqPoint = FieldVector>; pub type Fp2Chip<'chip, F> = fp2::Fp2Chip<'chip, F, FpChip<'chip, F>, Fq2>; pub type Fp12Chip<'chip, F> = fp12::Fp12Chip<'chip, F, FpChip<'chip, F>, Fq12, 9>; diff --git a/halo2-ecc/src/bn254/pairing.rs b/halo2-ecc/src/bn254/pairing.rs index cc4c9a87..e25f066a 100644 --- a/halo2-ecc/src/bn254/pairing.rs +++ b/halo2-ecc/src/bn254/pairing.rs @@ -1,15 +1,15 @@ #![allow(non_snake_case)] -use super::{Fp12Chip, Fp2Chip, FpChip, FpPoint, FqPoint}; +use super::{Fp12Chip, Fp2Chip, FpChip, FpPoint, Fq, FqPoint}; +use crate::fields::vector::FieldVector; use crate::halo2_proofs::halo2curves::bn256::{ G1Affine, G2Affine, FROBENIUS_COEFF_FQ12_C1, SIX_U_PLUS_2_NAF, }; use crate::{ ecc::{EcPoint, EccChip}, fields::fp12::mul_no_carry_w6, - fields::{FieldChip, FieldExtPoint, PrimeField}, + fields::{FieldChip, PrimeField}, }; use halo2_base::Context; -use num_bigint::BigUint; const XI_0: i64 = 9; @@ -30,25 +30,25 @@ pub fn sparse_line_function_unequal( let (x_1, y_1) = (&Q.0.x, &Q.0.y); let (x_2, y_2) = (&Q.1.x, &Q.1.y); let (X, Y) = (&P.x, &P.y); - assert_eq!(x_1.coeffs.len(), 2); - assert_eq!(y_1.coeffs.len(), 2); - assert_eq!(x_2.coeffs.len(), 2); - assert_eq!(y_2.coeffs.len(), 2); + assert_eq!(x_1.0.len(), 2); + assert_eq!(y_1.0.len(), 2); + assert_eq!(x_2.0.len(), 2); + assert_eq!(y_2.0.len(), 2); let y1_minus_y2 = fp2_chip.sub_no_carry(ctx, y_1, y_2); let x2_minus_x1 = fp2_chip.sub_no_carry(ctx, x_2, x_1); let x1y2 = fp2_chip.mul_no_carry(ctx, x_1, y_2); let x2y1 = fp2_chip.mul_no_carry(ctx, x_2, y_1); - let out3 = fp2_chip.fp_mul_no_carry(ctx, &y1_minus_y2, X); - let out2 = fp2_chip.fp_mul_no_carry(ctx, &x2_minus_x1, Y); + let out3 = fp2_chip.0.fp_mul_no_carry(ctx, y1_minus_y2, X); + let out2 = fp2_chip.0.fp_mul_no_carry(ctx, x2_minus_x1, Y); let out5 = fp2_chip.sub_no_carry(ctx, &x1y2, &x2y1); // so far we have not "carried mod p" for any of the outputs // we do this below - vec![None, None, Some(out2), Some(out3), None, Some(out5)] - .iter() - .map(|option_nc| option_nc.as_ref().map(|nocarry| fp2_chip.carry_mod(ctx, nocarry))) + [None, None, Some(out2), Some(out3), None, Some(out5)] + .into_iter() + .map(|option_nc| option_nc.map(|nocarry| fp2_chip.carry_mod(ctx, nocarry))) .collect() } @@ -67,8 +67,8 @@ pub fn sparse_line_function_equal( P: &EcPoint>, ) -> Vec>> { let (x, y) = (&Q.x, &Q.y); - assert_eq!(x.coeffs.len(), 2); - assert_eq!(y.coeffs.len(), 2); + assert_eq!(x.0.len(), 2); + assert_eq!(y.0.len(), 2); let x_sq = fp2_chip.mul(ctx, x, x); @@ -77,19 +77,19 @@ pub fn sparse_line_function_equal( let y_sq = fp2_chip.mul_no_carry(ctx, y, y); let two_y_sq = fp2_chip.scalar_mul_no_carry(ctx, &y_sq, 2); let out0_left = fp2_chip.sub_no_carry(ctx, &three_x_cu, &two_y_sq); - let out0 = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, &out0_left); + let out0 = mul_no_carry_w6::<_, _, XI_0>(fp2_chip.fp_chip(), ctx, out0_left); - let x_sq_Px = fp2_chip.fp_mul_no_carry(ctx, &x_sq, &P.x); - let out4 = fp2_chip.scalar_mul_no_carry(ctx, &x_sq_Px, -3); + let x_sq_Px = fp2_chip.0.fp_mul_no_carry(ctx, x_sq, &P.x); + let out4 = fp2_chip.scalar_mul_no_carry(ctx, x_sq_Px, -3); - let y_Py = fp2_chip.fp_mul_no_carry(ctx, y, &P.y); + let y_Py = fp2_chip.0.fp_mul_no_carry(ctx, y.clone(), &P.y); let out3 = fp2_chip.scalar_mul_no_carry(ctx, &y_Py, 2); // so far we have not "carried mod p" for any of the outputs // we do this below - vec![Some(out0), None, None, Some(out3), Some(out4), None] - .iter() - .map(|option_nc| option_nc.as_ref().map(|nocarry| fp2_chip.carry_mod(ctx, nocarry))) + [Some(out0), None, None, Some(out3), Some(out4), None] + .into_iter() + .map(|option_nc| option_nc.map(|nocarry| fp2_chip.carry_mod(ctx, nocarry))) .collect() } @@ -99,16 +99,16 @@ pub fn sparse_fp12_multiply( fp2_chip: &Fp2Chip, ctx: &mut Context, a: &FqPoint, - b_fp2_coeffs: &Vec>>, -) -> FieldExtPoint> { - assert_eq!(a.coeffs.len(), 12); + b_fp2_coeffs: &[Option>], +) -> FqPoint { + assert_eq!(a.0.len(), 12); assert_eq!(b_fp2_coeffs.len(), 6); let mut a_fp2_coeffs = Vec::with_capacity(6); for i in 0..6 { - a_fp2_coeffs.push(FqPoint::construct(vec![a.coeffs[i].clone(), a.coeffs[i + 6].clone()])); + a_fp2_coeffs.push(FieldVector(vec![a[i].clone(), a[i + 6].clone()])); } // a * b as element of Fp2[w] without evaluating w^6 = (XI_0 + u) - let mut prod_2d: Vec>>> = vec![None; 11]; + let mut prod_2d = vec![None; 11]; for i in 0..6 { for j in 0..6 { prod_2d[i + j] = @@ -133,7 +133,7 @@ pub fn sparse_fp12_multiply( let prod_nocarry = if i != 5 { let eval_w6 = prod_2d[i + 6] .as_ref() - .map(|a| mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, a)); + .map(|a| mul_no_carry_w6::<_, _, XI_0>(fp2_chip.fp_chip(), ctx, a.clone())); match (prod_2d[i].as_ref(), eval_w6) { (None, b) => b.unwrap(), // Our current use cases of 235 and 034 sparse multiplication always result in non-None value (Some(a), None) => a.clone(), @@ -142,18 +142,18 @@ pub fn sparse_fp12_multiply( } else { prod_2d[i].clone().unwrap() }; - let prod = fp2_chip.carry_mod(ctx, &prod_nocarry); + let prod = fp2_chip.carry_mod(ctx, prod_nocarry); out_fp2.push(prod); } let mut out_coeffs = Vec::with_capacity(12); for fp2_coeff in &out_fp2 { - out_coeffs.push(fp2_coeff.coeffs[0].clone()); + out_coeffs.push(fp2_coeff[0].clone()); } for fp2_coeff in &out_fp2 { - out_coeffs.push(fp2_coeff.coeffs[1].clone()); + out_coeffs.push(fp2_coeff[1].clone()); } - FqPoint::construct(out_coeffs) + FieldVector(out_coeffs) } // Input: @@ -221,7 +221,7 @@ pub fn miller_loop_BN( } let last_index = i; - let neg_Q = ecc_chip.negate(ctx, Q); + let neg_Q = ecc_chip.negate(ctx, Q.clone()); assert!(pseudo_binary_encoding[i] == 1 || pseudo_binary_encoding[i] == -1); let mut R = if pseudo_binary_encoding[i] == 1 { Q.clone() } else { neg_Q.clone() }; i -= 1; @@ -230,28 +230,29 @@ pub fn miller_loop_BN( let sparse_f = sparse_line_function_equal::(ecc_chip.field_chip(), ctx, &R, P); assert_eq!(sparse_f.len(), 6); - let zero_fp = ecc_chip.field_chip.fp_chip.load_constant(ctx, BigUint::from(0u64)); + let fp_chip = ecc_chip.field_chip.fp_chip(); + let zero_fp = fp_chip.load_constant(ctx, Fq::zero()); let mut f_coeffs = Vec::with_capacity(12); for coeff in &sparse_f { if let Some(fp2_point) = coeff { - f_coeffs.push(fp2_point.coeffs[0].clone()); + f_coeffs.push(fp2_point[0].clone()); } else { f_coeffs.push(zero_fp.clone()); } } for coeff in &sparse_f { if let Some(fp2_point) = coeff { - f_coeffs.push(fp2_point.coeffs[1].clone()); + f_coeffs.push(fp2_point[1].clone()); } else { f_coeffs.push(zero_fp.clone()); } } - let mut f = FqPoint::construct(f_coeffs); + let mut f = FieldVector(f_coeffs); + let fp12_chip = Fp12Chip::::new(fp_chip); loop { if i != last_index - 1 { - let fp12_chip = Fp12Chip::::new(ecc_chip.field_chip.fp_chip); let f_sq = fp12_chip.mul(ctx, &f, &f); f = fp12_multiply_with_line_equal::(ecc_chip.field_chip(), ctx, &f_sq, &R, P); } @@ -308,29 +309,30 @@ pub fn multi_miller_loop_BN( let neg_b = pairs.iter().map(|pair| ecc_chip.negate(ctx, pair.1)).collect::>(); + let fp_chip = ecc_chip.field_chip.fp_chip(); // initialize the first line function into Fq12 point let mut f = { let sparse_f = sparse_line_function_equal::(ecc_chip.field_chip(), ctx, pairs[0].1, pairs[0].0); assert_eq!(sparse_f.len(), 6); - let zero_fp = ecc_chip.field_chip.fp_chip.load_constant(ctx, BigUint::from(0u64)); + let zero_fp = fp_chip.load_constant(ctx, Fq::zero()); let mut f_coeffs = Vec::with_capacity(12); for coeff in &sparse_f { if let Some(fp2_point) = coeff { - f_coeffs.push(fp2_point.coeffs[0].clone()); + f_coeffs.push(fp2_point[0].clone()); } else { f_coeffs.push(zero_fp.clone()); } } for coeff in &sparse_f { if let Some(fp2_point) = coeff { - f_coeffs.push(fp2_point.coeffs[1].clone()); + f_coeffs.push(fp2_point[1].clone()); } else { f_coeffs.push(zero_fp.clone()); } } - FqPoint::construct(f_coeffs) + FieldVector(f_coeffs) }; for &(a, b) in pairs.iter().skip(1) { f = fp12_multiply_with_line_equal::(ecc_chip.field_chip(), ctx, &f, b, a); @@ -338,7 +340,7 @@ pub fn multi_miller_loop_BN( i -= 1; let mut r = pairs.iter().map(|pair| pair.1.clone()).collect::>(); - let fp12_chip = Fp12Chip::::new(ecc_chip.field_chip.fp_chip); + let fp12_chip = Fp12Chip::::new(fp_chip); loop { if i != last_index - 1 { f = fp12_chip.mul(ctx, &f, &f); @@ -347,7 +349,7 @@ pub fn multi_miller_loop_BN( } } for r in r.iter_mut() { - *r = ecc_chip.double(ctx, r); + *r = ecc_chip.double(ctx, r.clone()); } assert!(pseudo_binary_encoding[i] <= 1 && pseudo_binary_encoding[i] >= -1); @@ -361,7 +363,7 @@ pub fn multi_miller_loop_BN( (r, sign_b), a, ); - *r = ecc_chip.add_unequal(ctx, r, sign_b, false); + *r = ecc_chip.add_unequal(ctx, r.clone(), sign_b, false); } } if i == 0 { @@ -378,11 +380,11 @@ pub fn multi_miller_loop_BN( let c3 = ecc_chip.field_chip.load_constant(ctx, c3); // finish multiplying remaining line functions outside the loop - for (r, &(a, b)) in r.iter_mut().zip(pairs.iter()) { - let b_1 = twisted_frobenius::(ecc_chip, ctx, b, &c2, &c3); - let neg_b_2 = neg_twisted_frobenius::(ecc_chip, ctx, &b_1, &c2, &c3); - f = fp12_multiply_with_line_unequal::(ecc_chip.field_chip(), ctx, &f, (r, &b_1), a); - *r = ecc_chip.add_unequal(ctx, r, &b_1, false); + for (r, (a, b)) in r.iter_mut().zip(pairs) { + let b_1 = twisted_frobenius(ecc_chip, ctx, b, &c2, &c3); + let neg_b_2 = neg_twisted_frobenius(ecc_chip, ctx, &b_1, &c2, &c3); + f = fp12_multiply_with_line_unequal(ecc_chip.field_chip(), ctx, &f, (r, &b_1), a); + *r = ecc_chip.add_unequal(ctx, r.clone(), b_1, false); f = fp12_multiply_with_line_unequal::(ecc_chip.field_chip(), ctx, &f, (r, &neg_b_2), a); } f @@ -398,18 +400,21 @@ pub fn multi_miller_loop_BN( pub fn twisted_frobenius( ecc_chip: &EccChip>, ctx: &mut Context, - Q: &EcPoint>, - c2: &FqPoint, - c3: &FqPoint, + Q: impl Into>>, + c2: impl Into>, + c3: impl Into>, ) -> EcPoint> { - assert_eq!(c2.coeffs.len(), 2); - assert_eq!(c3.coeffs.len(), 2); - - let frob_x = ecc_chip.field_chip.conjugate(ctx, &Q.x); - let frob_y = ecc_chip.field_chip.conjugate(ctx, &Q.y); - let out_x = ecc_chip.field_chip.mul(ctx, c2, &frob_x); - let out_y = ecc_chip.field_chip.mul(ctx, c3, &frob_y); - EcPoint::construct(out_x, out_y) + let Q = Q.into(); + let c2 = c2.into(); + let c3 = c3.into(); + assert_eq!(c2.0.len(), 2); + assert_eq!(c3.0.len(), 2); + + let frob_x = ecc_chip.field_chip.conjugate(ctx, Q.x); + let frob_y = ecc_chip.field_chip.conjugate(ctx, Q.y); + let out_x = ecc_chip.field_chip.mul(ctx, c2, frob_x); + let out_y = ecc_chip.field_chip.mul(ctx, c3, frob_y); + EcPoint::new(out_x, out_y) } // Frobenius coefficient coeff[1][j] = ((9+u)^{(p-1)/6})^j @@ -421,18 +426,21 @@ pub fn twisted_frobenius( pub fn neg_twisted_frobenius( ecc_chip: &EccChip>, ctx: &mut Context, - Q: &EcPoint>, - c2: &FqPoint, - c3: &FqPoint, + Q: impl Into>>, + c2: impl Into>, + c3: impl Into>, ) -> EcPoint> { - assert_eq!(c2.coeffs.len(), 2); - assert_eq!(c3.coeffs.len(), 2); - - let frob_x = ecc_chip.field_chip.conjugate(ctx, &Q.x); - let neg_frob_y = ecc_chip.field_chip.neg_conjugate(ctx, &Q.y); - let out_x = ecc_chip.field_chip.mul(ctx, c2, &frob_x); - let out_y = ecc_chip.field_chip.mul(ctx, c3, &neg_frob_y); - EcPoint::construct(out_x, out_y) + let Q = Q.into(); + let c2 = c2.into(); + let c3 = c3.into(); + assert_eq!(c2.0.len(), 2); + assert_eq!(c3.0.len(), 2); + + let frob_x = ecc_chip.field_chip.conjugate(ctx, Q.x); + let neg_frob_y = ecc_chip.field_chip.neg_conjugate(ctx, Q.y); + let out_x = ecc_chip.field_chip.mul(ctx, c2, frob_x); + let out_y = ecc_chip.field_chip.mul(ctx, c3, neg_frob_y); + EcPoint::new(out_x, out_y) } // To avoid issues with mutably borrowing twice (not allowed in Rust), we only store fp_chip and construct g2_chip and fp12_chip in scope when needed for temporary mutable borrows @@ -445,19 +453,23 @@ impl<'chip, F: PrimeField> PairingChip<'chip, F> { Self { fp_chip } } - pub fn load_private_g1(&self, ctx: &mut Context, point: G1Affine) -> EcPoint> { + pub fn load_private_g1_unchecked( + &self, + ctx: &mut Context, + point: G1Affine, + ) -> EcPoint> { let g1_chip = EccChip::new(self.fp_chip); - g1_chip.load_private(ctx, (point.x, point.y)) + g1_chip.load_private_unchecked(ctx, (point.x, point.y)) } - pub fn load_private_g2( + pub fn load_private_g2_unchecked( &self, ctx: &mut Context, point: G2Affine, - ) -> EcPoint>> { - let fp2_chip = Fp2Chip::::new(self.fp_chip); + ) -> EcPoint> { + let fp2_chip = Fp2Chip::new(self.fp_chip); let g2_chip = EccChip::new(&fp2_chip); - g2_chip.load_private(ctx, (point.x, point.y)) + g2_chip.load_private_unchecked(ctx, (point.x, point.y)) } pub fn miller_loop( @@ -492,7 +504,7 @@ impl<'chip, F: PrimeField> PairingChip<'chip, F> { ) } - pub fn final_exp(&self, ctx: &mut Context, f: &FqPoint) -> FqPoint { + pub fn final_exp(&self, ctx: &mut Context, f: FqPoint) -> FqPoint { let fp12_chip = Fp12Chip::::new(self.fp_chip); fp12_chip.final_exp(ctx, f) } @@ -507,6 +519,6 @@ impl<'chip, F: PrimeField> PairingChip<'chip, F> { let f0 = self.miller_loop(ctx, Q, P); let fp12_chip = Fp12Chip::::new(self.fp_chip); // final_exp implemented in final_exp module - fp12_chip.final_exp(ctx, &f0) + fp12_chip.final_exp(ctx, f0) } } diff --git a/halo2-ecc/src/bn254/tests/ec_add.rs b/halo2-ecc/src/bn254/tests/ec_add.rs index 30c52aa5..a902ce3c 100644 --- a/halo2-ecc/src/bn254/tests/ec_add.rs +++ b/halo2-ecc/src/bn254/tests/ec_add.rs @@ -33,13 +33,14 @@ fn g2_add_test(ctx: &mut Context, params: CircuitParams, _poin let fp2_chip = Fp2Chip::::new(&fp_chip); let g2_chip = EccChip::new(&fp2_chip); - let points = _points.iter().map(|pt| g2_chip.assign_point(ctx, *pt)).collect::>(); + let points = + _points.iter().map(|pt| g2_chip.assign_point_unchecked(ctx, *pt)).collect::>(); - let acc = g2_chip.sum::(ctx, points.iter()); + let acc = g2_chip.sum::(ctx, points); let answer = _points.iter().fold(G2Affine::identity(), |a, b| (a + b).to_affine()); - let x = fp2_chip.get_assigned_value(&acc.x); - let y = fp2_chip.get_assigned_value(&acc.y); + let x = fp2_chip.get_assigned_value(&acc.x.into()); + let y = fp2_chip.get_assigned_value(&acc.y.into()); assert_eq!(answer.x, x); assert_eq!(answer.y, y); } diff --git a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs index b3769301..a8f039c2 100644 --- a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs +++ b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs @@ -3,8 +3,6 @@ use std::{ io::{BufRead, BufReader}, }; -#[allow(unused_imports)] -use crate::ecc::fixed_base::FixedEcPoint; use crate::fields::{FpStrategy, PrimeField}; use super::*; @@ -21,6 +19,7 @@ use halo2_base::{ halo2_proofs::halo2curves::bn256::G1, utils::fs::gen_srs, }; +use itertools::Itertools; use rand_core::OsRng; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] @@ -62,14 +61,15 @@ fn fixed_base_msm_test( } let msm_answer = elts.into_iter().reduce(|a, b| a + b).unwrap().to_affine(); - let msm_x = msm.x.value; - let msm_y = msm.y.value; - assert_eq!(msm_x, fe_to_biguint(&msm_answer.x).into()); - assert_eq!(msm_y, fe_to_biguint(&msm_answer.y).into()); + let msm_x = msm.x.value(); + let msm_y = msm.y.value(); + assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); + assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); } fn random_fixed_base_msm_circuit( params: MSMCircuitParams, + bases: Vec, // bases are fixed in vkey so don't randomly generate stage: CircuitBuilderStage, break_points: Option, ) -> RangeCircuitBuilder { @@ -80,8 +80,7 @@ fn random_fixed_base_msm_circuit( CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), }; - let (bases, scalars): (Vec<_>, Vec<_>) = - (0..params.batch_size).map(|_| (G1Affine::random(OsRng), Fr::random(OsRng))).unzip(); + let scalars = (0..params.batch_size).map(|_| Fr::random(OsRng)).collect_vec(); let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); fixed_base_msm_test(&mut builder, params, bases, scalars); @@ -108,7 +107,8 @@ fn test_fixed_base_msm() { ) .unwrap(); - let circuit = random_fixed_base_msm_circuit(params, CircuitBuilderStage::Mock, None); + let bases = (0..params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec(); + let circuit = random_fixed_base_msm_circuit(params, bases, CircuitBuilderStage::Mock, None); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -134,8 +134,13 @@ fn bench_fixed_base_msm() -> Result<(), Box> { let params = gen_srs(k); println!("{bench_params:?}"); - let circuit = - random_fixed_base_msm_circuit(bench_params, CircuitBuilderStage::Keygen, None); + let bases = (0..bench_params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec(); + let circuit = random_fixed_base_msm_circuit( + bench_params, + bases.clone(), + CircuitBuilderStage::Keygen, + None, + ); let vk_time = start_timer!(|| "Generating vkey"); let vk = keygen_vk(¶ms, &circuit)?; @@ -151,6 +156,7 @@ fn bench_fixed_base_msm() -> Result<(), Box> { let proof_time = start_timer!(|| "Proving time"); let circuit = random_fixed_base_msm_circuit( bench_params, + bases, CircuitBuilderStage::Prover, Some(break_points), ); diff --git a/halo2-ecc/src/bn254/tests/msm.rs b/halo2-ecc/src/bn254/tests/msm.rs index 86334f25..cfc7d40f 100644 --- a/halo2-ecc/src/bn254/tests/msm.rs +++ b/halo2-ecc/src/bn254/tests/msm.rs @@ -47,8 +47,10 @@ fn msm_test( let ctx = builder.main(0); let scalars_assigned = scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); - let bases_assigned = - bases.iter().map(|base| ecc_chip.load_private(ctx, (base.x, base.y))).collect::>(); + let bases_assigned = bases + .iter() + .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) + .collect::>(); let msm = ecc_chip.variable_base_msm_in::( builder, @@ -67,10 +69,10 @@ fn msm_test( .unwrap() .to_affine(); - let msm_x = msm.x.value; - let msm_y = msm.y.value; - assert_eq!(msm_x, fe_to_biguint(&msm_answer.x).into()); - assert_eq!(msm_y, fe_to_biguint(&msm_answer.y).into()); + let msm_x = msm.x.value(); + let msm_y = msm.y.value(); + assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); + assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); } fn random_msm_circuit( diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs index eb82c82c..052edea4 100644 --- a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs @@ -1,20 +1,13 @@ use crate::fields::FpStrategy; -use ff::{Field, PrimeField}; -use halo2_base::{ - gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, +use ff::PrimeField; +use halo2_base::gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, }, - utils::fs::gen_srs, + RangeChip, }; use rand_core::OsRng; -use std::{ - fs::{self, File}, - io::{BufRead, BufReader}, -}; +use std::fs::File; use super::*; @@ -115,16 +108,15 @@ fn test_msm1() { ) .unwrap(); params.batch_size = 3; - + let random_point = G1Affine::random(OsRng); let bases = vec![random_point, random_point, random_point]; - let scalars = vec![Fr::one(),Fr::one(),-Fr::one()-Fr::one()]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } - #[test] fn test_msm2() { let path = "configs/bn254/msm_circuit.config"; @@ -133,16 +125,15 @@ fn test_msm2() { ) .unwrap(); params.batch_size = 3; - + let random_point = G1Affine::random(OsRng); let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; - let scalars = vec![Fr::one(),Fr::one(),-Fr::one()]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } - #[test] fn test_msm3() { let path = "configs/bn254/msm_circuit.config"; @@ -151,16 +142,20 @@ fn test_msm3() { ) .unwrap(); params.batch_size = 4; - + let random_point = G1Affine::random(OsRng); - let bases = vec![random_point, random_point, random_point, (random_point + random_point + random_point).to_affine()]; - let scalars = vec![Fr::one(),Fr::one(),Fr::one(),-Fr::one()]; + let bases = vec![ + random_point, + random_point, + random_point, + (random_point + random_point + random_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } - #[test] fn test_msm4() { let path = "configs/bn254/msm_circuit.config"; @@ -169,10 +164,15 @@ fn test_msm4() { ) .unwrap(); params.batch_size = 4; - + let generator_point = G1Affine::generator(); - let bases = vec![generator_point, generator_point, generator_point, (generator_point + generator_point + generator_point).to_affine()]; - let scalars = vec![Fr::one(),Fr::one(),Fr::one(),-Fr::one()]; + let bases = vec![ + generator_point, + generator_point, + generator_point, + (generator_point + generator_point + generator_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); @@ -187,10 +187,11 @@ fn test_msm5() { ) .unwrap(); params.batch_size = 4; - + let random_point = G1Affine::random(OsRng); - let bases = vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; - let scalars = vec![-Fr::one(),-Fr::one(),Fr::one(),Fr::one()]; + let bases = + vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); diff --git a/halo2-ecc/src/bn254/tests/pairing.rs b/halo2-ecc/src/bn254/tests/pairing.rs index 703736b7..37f82684 100644 --- a/halo2-ecc/src/bn254/tests/pairing.rs +++ b/halo2-ecc/src/bn254/tests/pairing.rs @@ -43,8 +43,8 @@ fn pairing_test( let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let chip = PairingChip::new(&fp_chip); - let P_assigned = chip.load_private_g1(ctx, P); - let Q_assigned = chip.load_private_g2(ctx, Q); + let P_assigned = chip.load_private_g1_unchecked(ctx, P); + let Q_assigned = chip.load_private_g2_unchecked(ctx, Q); // test optimal ate pairing let f = chip.pairing(ctx, &Q_assigned, &P_assigned); @@ -52,7 +52,10 @@ fn pairing_test( let actual_f = pairing(&P, &Q); let fp12_chip = Fp12Chip::new(&fp_chip); // cannot directly compare f and actual_f because `Gt` has private field `Fq12` - assert_eq!(format!("Gt({:?})", fp12_chip.get_assigned_value(&f)), format!("{actual_f:?}")); + assert_eq!( + format!("Gt({:?})", fp12_chip.get_assigned_value(&f.into())), + format!("{actual_f:?}") + ); } fn random_pairing_circuit( diff --git a/halo2-ecc/src/ecc/ecdsa.rs b/halo2-ecc/src/ecc/ecdsa.rs index 874c185f..d7406a17 100644 --- a/halo2-ecc/src/ecc/ecdsa.rs +++ b/halo2-ecc/src/ecc/ecdsa.rs @@ -1,102 +1,106 @@ -use crate::bigint::{big_less_than, CRTInteger}; +use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context}; + +use crate::bigint::{big_is_equal, big_less_than, FixedOverflowInteger, ProperCrtUint}; use crate::fields::{fp::FpChip, FieldChip, PrimeField}; -use halo2_base::{ - gates::{GateInstructions, RangeInstructions}, - utils::CurveAffineExt, - AssignedValue, Context, -}; -use super::fixed_base; -use super::{ec_add_unequal, scalar_multiply, EcPoint}; +use super::{fixed_base, EccChip}; +use super::{scalar_multiply, EcPoint}; // CF is the coordinate field of GA // SF is the scalar field of GA // p = coordinate field modulus // n = scalar field modulus // Only valid when p is very close to n in size (e.g. for Secp256k1) +// Assumes `r, s` are proper CRT integers +/// **WARNING**: Only use this function if `1 / (p - n)` is very small (e.g., < 2-100) pub fn ecdsa_verify_no_pubkey_check( - base_chip: &FpChip, + chip: &EccChip>, ctx: &mut Context, - pubkey: &EcPoint as FieldChip>::FieldPoint>, - r: &CRTInteger, - s: &CRTInteger, - msghash: &CRTInteger, + pubkey: EcPoint as FieldChip>::FieldPoint>, + r: ProperCrtUint, + s: ProperCrtUint, + msghash: ProperCrtUint, var_window_bits: usize, fixed_window_bits: usize, ) -> AssignedValue where GA: CurveAffineExt, { + // Following https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm + let base_chip = chip.field_chip; let scalar_chip = FpChip::::new(base_chip.range, base_chip.limb_bits, base_chip.num_limbs); - let n = scalar_chip.load_constant(ctx, scalar_chip.p.to_biguint().unwrap()); + let n = scalar_chip.p.to_biguint().unwrap(); + let n = FixedOverflowInteger::from_native(&n, scalar_chip.num_limbs, scalar_chip.limb_bits); + let n = n.assign(ctx); // check r,s are in [1, n - 1] - let r_valid = scalar_chip.is_soft_nonzero(ctx, r); - let s_valid = scalar_chip.is_soft_nonzero(ctx, s); + let r_valid = scalar_chip.is_soft_nonzero(ctx, &r); + let s_valid = scalar_chip.is_soft_nonzero(ctx, &s); // compute u1 = m s^{-1} mod n and u2 = r s^{-1} mod n - let u1 = scalar_chip.divide(ctx, msghash, s); - let u2 = scalar_chip.divide(ctx, r, s); - - //let r_crt = scalar_chip.to_crt(ctx, r)?; + let u1 = scalar_chip.divide_unsafe(ctx, msghash, &s); + let u2 = scalar_chip.divide_unsafe(ctx, &r, s); // compute u1 * G and u2 * pubkey - let u1_mul = fixed_base::scalar_multiply::( + let u1_mul = fixed_base::scalar_multiply( base_chip, ctx, &GA::generator(), - u1.truncation.limbs.clone(), + u1.limbs().to_vec(), base_chip.limb_bits, fixed_window_bits, + true, // we can call it with scalar_is_safe = true because of the u1_small check below ); - let u2_mul = scalar_multiply::( + let u2_mul = scalar_multiply( base_chip, ctx, pubkey, - u2.truncation.limbs.clone(), + u2.limbs().to_vec(), base_chip.limb_bits, var_window_bits, + true, // we can call it with scalar_is_safe = true because of the u2_small check below ); - // check u1 * G and u2 * pubkey are not negatives and not equal - // TODO: Technically they could be equal for a valid signature, but this happens with vanishing probability - // for an ECDSA signature constructed in a standard way + // check u1 * G != -(u2 * pubkey) but allow u1 * G == u2 * pubkey + // check (u1 * G).x != (u2 * pubkey).x or (u1 * G).y == (u2 * pubkey).y // coordinates of u1_mul and u2_mul are in proper bigint form, and lie in but are not constrained to [0, n) // we therefore need hard inequality here - let u1_u2_x_eq = base_chip.is_equal(ctx, &u1_mul.x, &u2_mul.x); - let u1_u2_not_neg = base_chip.range.gate().not(ctx, u1_u2_x_eq); + let x_eq = base_chip.is_equal(ctx, &u1_mul.x, &u2_mul.x); + let x_neq = base_chip.gate().not(ctx, x_eq); + let y_eq = base_chip.is_equal(ctx, &u1_mul.y, &u2_mul.y); + let u1g_u2pk_not_neg = base_chip.gate().or(ctx, x_neq, y_eq); // compute (x1, y1) = u1 * G + u2 * pubkey and check (r mod n) == x1 as integers + // because it is possible for u1 * G == u2 * pubkey, we must use `EccChip::sum` + let sum = chip.sum::(ctx, [u1_mul, u2_mul]); // WARNING: For optimization reasons, does not reduce x1 mod n, which is // invalid unless p is very close to n in size. - base_chip.enforce_less_than_p(ctx, u1_mul.x()); - base_chip.enforce_less_than_p(ctx, u2_mul.x()); - let sum = ec_add_unequal(base_chip, ctx, &u1_mul, &u2_mul, false); - let equal_check = base_chip.is_equal(ctx, &sum.x, r); + // enforce x1 < n + let x1 = scalar_chip.enforce_less_than(ctx, sum.x); + let equal_check = big_is_equal::assign(base_chip.gate(), ctx, x1.0, r); - // TODO: maybe the big_less_than is optional? - let u1_small = big_less_than::assign::( + let u1_small = big_less_than::assign( base_chip.range(), ctx, - &u1.truncation, - &n.truncation, + u1, + n.clone(), base_chip.limb_bits, base_chip.limb_bases[1], ); - let u2_small = big_less_than::assign::( + let u2_small = big_less_than::assign( base_chip.range(), ctx, - &u2.truncation, - &n.truncation, + u2, + n, base_chip.limb_bits, base_chip.limb_bases[1], ); - // check (r in [1, n - 1]) and (s in [1, n - 1]) and (u1_mul != - u2_mul) and (r == x1 mod n) + // check (r in [1, n - 1]) and (s in [1, n - 1]) and (u1 * G != - u2 * pubkey) and (r == x1 mod n) let res1 = base_chip.gate().and(ctx, r_valid, s_valid); let res2 = base_chip.gate().and(ctx, res1, u1_small); let res3 = base_chip.gate().and(ctx, res2, u2_small); - let res4 = base_chip.gate().and(ctx, res3, u1_u2_not_neg); + let res4 = base_chip.gate().and(ctx, res3, u1g_u2pk_not_neg); let res5 = base_chip.gate().and(ctx, res4, equal_check); res5 } diff --git a/halo2-ecc/src/ecc/fixed_base.rs b/halo2-ecc/src/ecc/fixed_base.rs index 52e6634f..dc67b8d6 100644 --- a/halo2-ecc/src/ecc/fixed_base.rs +++ b/halo2-ecc/src/ecc/fixed_base.rs @@ -1,62 +1,25 @@ #![allow(non_snake_case)] use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip}; -use crate::halo2_proofs::arithmetic::CurveAffine; -use crate::{ - bigint::{CRTInteger, FixedCRTInteger}, - fields::{PrimeField, PrimeFieldChip, Selectable}, -}; +use crate::ecc::ec_sub_strict; +use crate::fields::{FieldChip, PrimeField, Selectable}; use group::Curve; -use halo2_base::gates::builder::GateThreadBuilder; -use halo2_base::{ - gates::GateInstructions, - utils::{fe_to_biguint, CurveAffineExt}, - AssignedValue, Context, -}; +use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder}; +use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context}; use itertools::Itertools; use rayon::prelude::*; -use std::{cmp::min, marker::PhantomData}; - -// this only works for curves GA with base field of prime order -#[derive(Clone, Debug)] -pub struct FixedEcPoint { - pub x: FixedCRTInteger, // limbs in `F` and value in `BigUint` - pub y: FixedCRTInteger, - _marker: PhantomData, -} - -impl FixedEcPoint -where - C::Base: PrimeField, -{ - pub fn construct(x: FixedCRTInteger, y: FixedCRTInteger) -> Self { - Self { x, y, _marker: PhantomData } - } - - pub fn from_curve(point: C, num_limbs: usize, limb_bits: usize) -> Self { - let (x, y) = point.into_coordinates(); - let x = FixedCRTInteger::from_native(fe_to_biguint(&x), num_limbs, limb_bits); - let y = FixedCRTInteger::from_native(fe_to_biguint(&y), num_limbs, limb_bits); - Self::construct(x, y) - } - - pub fn assign(self, chip: &FC, ctx: &mut Context) -> EcPoint - where - FC: PrimeFieldChip>, - { - let assigned_x = self.x.assign(ctx, chip.limb_bits(), chip.native_modulus()); - let assigned_y = self.y.assign(ctx, chip.limb_bits(), chip.native_modulus()); - EcPoint::construct(assigned_x, assigned_y) - } -} - -// computes `[scalar] * P` on y^2 = x^3 + b where `P` is fixed (constant) -// - `scalar` is represented as a reference array of `AssignedCell`s -// - `scalar = sum_i scalar_i * 2^{max_bits * i}` -// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` -// assumes: -// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits) -// - `max_bits <= modulus::.bits()` - +use std::cmp::min; + +/// Computes `[scalar] * P` on y^2 = x^3 + b where `P` is fixed (constant) +/// - `scalar` is represented as a non-empty reference array of `AssignedValue`s +/// - `scalar = sum_i scalar_i * 2^{max_bits * i}` +/// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` +/// +/// # Assumptions +/// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits) +/// - `scalar > 0` +/// - If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`) +/// - Even if `scalar_is_safe == false`, some constraints may still fail if `scalar` is not in range [1, order of `P`) +/// - `max_bits <= modulus::.bits()` pub fn scalar_multiply( chip: &FC, ctx: &mut Context, @@ -64,20 +27,19 @@ pub fn scalar_multiply( scalar: Vec>, max_bits: usize, window_bits: usize, + scalar_is_safe: bool, ) -> EcPoint where F: PrimeField, C: CurveAffineExt, - C::Base: PrimeField, - FC: PrimeFieldChip> - + Selectable, + FC: FieldChip + Selectable, { if point.is_identity().into() { - let point = FixedEcPoint::from_curve(*point, chip.num_limbs(), chip.limb_bits()); - return FixedEcPoint::assign(point, chip, ctx); + let zero = chip.load_constant(ctx, C::Base::zero()); + return EcPoint::new(zero.clone(), zero); } - debug_assert!(!scalar.is_empty()); - debug_assert!((max_bits as u32) <= F::NUM_BITS); + assert!(!scalar.is_empty()); + assert!((max_bits as u32) <= F::NUM_BITS); let total_bits = max_bits * scalar.len(); let num_windows = (total_bits + window_bits - 1) / window_bits; @@ -112,8 +74,9 @@ where let cached_points = cached_points_affine .into_iter() .map(|point| { - let point = FixedEcPoint::from_curve(point, chip.num_limbs(), chip.limb_bits()); - FixedEcPoint::assign(point, chip, ctx) + let (x, y) = point.into_coordinates(); + let [x, y] = [x, y].map(|x| chip.load_constant(ctx, x)); + EcPoint::new(x, y) }) .collect_vec(); @@ -131,11 +94,11 @@ where let bit_sum = chip.gate().sum(ctx, bit_window.iter().copied()); // are we just adding a window of all 0s? if so, skip let is_zero_window = chip.gate().is_zero(ctx, bit_sum); - let add_point = ec_select_from_bits::(chip, ctx, cached_point_window, bit_window); + let add_point = ec_select_from_bits(chip, ctx, cached_point_window, bit_window); curr_point = if let Some(curr_point) = curr_point { - let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, false); - let zero_sum = ec_select(chip, ctx, &curr_point, &sum, is_zero_window); - Some(ec_select(chip, ctx, &zero_sum, &add_point, is_started)) + let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, !scalar_is_safe); + let zero_sum = ec_select(chip, ctx, curr_point, sum, is_zero_window); + Some(ec_select(chip, ctx, zero_sum, add_point, is_started)) } else { Some(add_point) }; @@ -151,116 +114,14 @@ where // basically just adding up individual fixed_base::scalar_multiply except that we do all batched normalization of cached points at once to further save inversion time during witness generation // we also use the random accumulator for some extra efficiency (which also works in scalar multiply case but that is TODO) -pub fn msm( - chip: &EccChip, - ctx: &mut Context, - points: &[C], - scalars: Vec>>, - max_scalar_bits_per_cell: usize, - window_bits: usize, -) -> EcPoint -where - F: PrimeField, - C: CurveAffineExt, - C::Base: PrimeField, - FC: PrimeFieldChip> - + Selectable, -{ - assert!((max_scalar_bits_per_cell as u32) <= F::NUM_BITS); - let scalar_len = scalars[0].len(); - let total_bits = max_scalar_bits_per_cell * scalar_len; - let num_windows = (total_bits + window_bits - 1) / window_bits; - - // `cached_points` is a flattened 2d vector - // first we compute all cached points in Jacobian coordinates since it's fastest - let cached_points_jacobian = points - .iter() - .flat_map(|point| { - let base_pt = point.to_curve(); - // cached_points[idx][i * 2^w + j] holds `[j * 2^(i * w)] * points[idx]` for j in {0, ..., 2^w - 1} - let mut increment = base_pt; - (0..num_windows) - .flat_map(|i| { - let mut curr = increment; - let cache_vec = std::iter::once(increment) - .chain((1..(1usize << min(window_bits, total_bits - i * window_bits))).map( - |_| { - let prev = curr; - curr += increment; - prev - }, - )) - .collect_vec(); - increment = curr; - cache_vec - }) - .collect_vec() - }) - .collect_vec(); - // for use in circuits we need affine coordinates, so we do a batch normalize: this is much more efficient than calling `to_affine` one by one since field inversion is very expensive - // initialize to all 0s - let mut cached_points_affine = vec![C::default(); cached_points_jacobian.len()]; - C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine); - - let field_chip = chip.field_chip(); - let cached_points = cached_points_affine - .into_iter() - .map(|point| { - let point = - FixedEcPoint::from_curve(point, field_chip.num_limbs(), field_chip.limb_bits()); - point.assign(field_chip, ctx) - }) - .collect_vec(); - - let bits = scalars - .into_iter() - .flat_map(|scalar| { - assert_eq!(scalar.len(), scalar_len); - scalar - .into_iter() - .flat_map(|scalar_chunk| { - field_chip.gate().num_to_bits(ctx, scalar_chunk, max_scalar_bits_per_cell) - }) - .collect_vec() - }) - .collect_vec(); - - let sm = cached_points - .chunks(cached_points.len() / points.len()) - .zip(bits.chunks(total_bits)) - .map(|(cached_points, bits)| { - let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev(); - let bit_window_rev = bits.chunks(window_bits).rev(); - let mut curr_point = None; - // `is_started` is just a way to deal with if `curr_point` is actually identity - let mut is_started = ctx.load_zero(); - for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) { - let is_zero_window = { - let sum = field_chip.gate().sum(ctx, bit_window.iter().copied()); - field_chip.gate().is_zero(ctx, sum) - }; - let add_point = - ec_select_from_bits::(field_chip, ctx, cached_point_window, bit_window); - curr_point = if let Some(curr_point) = curr_point { - let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, false); - let zero_sum = ec_select(field_chip, ctx, &curr_point, &sum, is_zero_window); - Some(ec_select(field_chip, ctx, &zero_sum, &add_point, is_started)) - } else { - Some(add_point) - }; - is_started = { - // is_started || !is_zero_window - // (a || !b) = (1-b) + a*b - let not_zero_window = field_chip.gate().not(ctx, is_zero_window); - field_chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window) - }; - } - curr_point.unwrap() - }) - .collect_vec(); - chip.sum::(ctx, sm.iter()) -} +/// # Assumptions +/// * `points.len() = scalars.len()` +/// * `scalars[i].len() = scalars[j].len()` for all `i,j` +/// * `points` are all on the curve +/// * `points[i]` is not point at infinity (0, 0); these should be filtered out beforehand +/// * The integer value of `scalars[i]` is less than the order of `points[i]` (some constraints may fail otherwise) +/// * Output may be point at infinity, in which case (0, 0) is returned pub fn msm_par( chip: &EccChip, builder: &mut GateThreadBuilder, @@ -273,10 +134,11 @@ pub fn msm_par( where F: PrimeField, C: CurveAffineExt, - C::Base: PrimeField, - FC: PrimeFieldChip> - + Selectable, + FC: FieldChip + Selectable, { + if points.is_empty() { + return chip.assign_constant_point(builder.main(phase), C::identity()); + } assert!((max_scalar_bits_per_cell as u32) <= F::NUM_BITS); assert_eq!(points.len(), scalars.len()); assert!(!points.is_empty(), "fixed_base::msm_par requires at least one point"); @@ -316,32 +178,23 @@ where C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine); let field_chip = chip.field_chip(); - let witness_gen_only = builder.witness_gen_only(); let zero = builder.main(phase).load_zero(); - let thread_ids = (0..scalars.len()).map(|_| builder.get_new_thread_id()).collect::>(); - let (new_threads, scalar_mults): (Vec<_>, Vec<_>) = cached_points_affine - .par_chunks(cached_points_affine.len() / points.len()) - .zip_eq(scalars.into_par_iter()) - .zip(thread_ids.into_par_iter()) - .map(|((cached_points, scalar), thread_id)| { - let mut thread = Context::new(witness_gen_only, thread_id); - let ctx = &mut thread; - + let scalar_mults = parallelize_in( + phase, + builder, + cached_points_affine + .chunks(cached_points_affine.len() / points.len()) + .zip_eq(scalars) + .collect(), + |ctx, (cached_points, scalar)| { let cached_points = cached_points .iter() - .map(|point| { - let point = FixedEcPoint::from_curve( - *point, - field_chip.num_limbs(), - field_chip.limb_bits(), - ); - point.assign(field_chip, ctx) - }) + .map(|point| chip.assign_constant_point(ctx, *point)) .collect_vec(); let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev(); - debug_assert_eq!(scalar.len(), scalar_len); + assert_eq!(scalar.len(), scalar_len); let bits = scalar .into_iter() .flat_map(|scalar_chunk| { @@ -358,11 +211,12 @@ where field_chip.gate().is_zero(ctx, sum) }; let add_point = - ec_select_from_bits::(field_chip, ctx, cached_point_window, bit_window); + ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window); curr_point = if let Some(curr_point) = curr_point { + // We don't need strict mode because we assume scalars[i] is less than the order of points[i] let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, false); - let zero_sum = ec_select(field_chip, ctx, &curr_point, &sum, is_zero_window); - Some(ec_select(field_chip, ctx, &zero_sum, &add_point, is_started)) + let zero_sum = ec_select(field_chip, ctx, curr_point, sum, is_zero_window); + Some(ec_select(field_chip, ctx, zero_sum, add_point, is_started)) } else { Some(add_point) }; @@ -373,9 +227,16 @@ where field_chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window) }; } - (thread, curr_point.unwrap()) - }) - .unzip(); - builder.threads[phase].extend(new_threads); - chip.sum::(builder.main(phase), scalar_mults.iter()) + (curr_point.unwrap(), is_started) + }, + ); + let ctx = builder.main(phase); + // sum `scalar_mults` but take into account possiblity of identity points + let any_point = chip.load_random_point::(ctx); + let mut acc = any_point.clone(); + for (point, is_not_identity) in scalar_mults { + let new_acc = chip.add_unequal(ctx, &acc, point, true); + acc = chip.select(ctx, new_acc, acc, is_not_identity); + } + ec_sub_strict(field_chip, ctx, acc, any_point) } diff --git a/halo2-ecc/src/ecc/mod.rs b/halo2-ecc/src/ecc/mod.rs index be40b85f..208eee11 100644 --- a/halo2-ecc/src/ecc/mod.rs +++ b/halo2-ecc/src/ecc/mod.rs @@ -1,6 +1,5 @@ #![allow(non_snake_case)] -use crate::bigint::CRTInteger; -use crate::fields::{fp::FpChip, FieldChip, PrimeField, PrimeFieldChip, Selectable}; +use crate::fields::{fp::FpChip, FieldChip, PrimeField, Selectable}; use crate::halo2_proofs::arithmetic::CurveAffine; use group::{Curve, Group}; use halo2_base::gates::builder::GateThreadBuilder; @@ -33,8 +32,17 @@ impl Clone for EcPoint { } } +// Improve readability by allowing `&EcPoint` to be converted to `EcPoint` via cloning +impl<'a, F: PrimeField, FieldPoint: Clone> From<&'a EcPoint> + for EcPoint +{ + fn from(value: &'a EcPoint) -> Self { + value.clone() + } +} + impl EcPoint { - pub fn construct(x: FieldPoint, y: FieldPoint) -> Self { + pub fn new(x: FieldPoint, y: FieldPoint) -> Self { Self { x, y, _marker: PhantomData } } @@ -47,6 +55,83 @@ impl EcPoint { } } +/// An elliptic curve point where it is easy to compare the x-coordinate of two points +#[derive(Clone, Debug)] +pub struct StrictEcPoint> { + pub x: FC::ReducedFieldPoint, + pub y: FC::FieldPoint, + _marker: PhantomData, +} + +impl> StrictEcPoint { + pub fn new(x: FC::ReducedFieldPoint, y: FC::FieldPoint) -> Self { + Self { x, y, _marker: PhantomData } + } +} + +impl> From> for EcPoint { + fn from(value: StrictEcPoint) -> Self { + Self::new(value.x.into(), value.y) + } +} + +impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> + for EcPoint +{ + fn from(value: &'a StrictEcPoint) -> Self { + value.clone().into() + } +} + +/// An elliptic curve point where the x-coordinate has already been constrained to be reduced or not. +/// In the reduced case one can more optimally compare equality of x-coordinates. +#[derive(Clone, Debug)] +pub enum ComparableEcPoint> { + Strict(StrictEcPoint), + NonStrict(EcPoint), +} + +impl> From> for ComparableEcPoint { + fn from(pt: StrictEcPoint) -> Self { + Self::Strict(pt) + } +} + +impl> From> + for ComparableEcPoint +{ + fn from(pt: EcPoint) -> Self { + Self::NonStrict(pt) + } +} + +impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> + for ComparableEcPoint +{ + fn from(pt: &'a StrictEcPoint) -> Self { + Self::Strict(pt.clone()) + } +} + +impl<'a, F: PrimeField, FC: FieldChip> From<&'a EcPoint> + for ComparableEcPoint +{ + fn from(pt: &'a EcPoint) -> Self { + Self::NonStrict(pt.clone()) + } +} + +impl> From> + for EcPoint +{ + fn from(pt: ComparableEcPoint) -> Self { + match pt { + ComparableEcPoint::Strict(pt) => Self::new(pt.x.into(), pt.y), + ComparableEcPoint::NonStrict(pt) => pt, + } + } +} + // Implements: // Given P = (x_1, y_1) and Q = (x_2, y_2), ecc points over the field F_p // assume x_1 != x_2 @@ -57,37 +142,61 @@ impl EcPoint { // x_3 = lambda^2 - x_1 - x_2 (mod p) // y_3 = lambda (x_1 - x_3) - y_1 mod p // -/// For optimization reasons, we assume that if you are using this with `is_strict = true`, then you have already called `chip.enforce_less_than_p` on both `P.x` and `Q.x` +/// If `is_strict = true`, then this function constrains that `P.x != Q.x`. +/// If you are calling this with `is_strict = false`, you must ensure that `P.x != Q.x` by some external logic (such +/// as a mathematical theorem). +/// +/// # Assumptions +/// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) pub fn ec_add_unequal>( chip: &FC, ctx: &mut Context, - P: &EcPoint, - Q: &EcPoint, + P: impl Into>, + Q: impl Into>, is_strict: bool, ) -> EcPoint { - if is_strict { - // constrains that P.x != Q.x - let x_is_equal = chip.is_equal_unenforced(ctx, &P.x, &Q.x); - chip.range().gate().assert_is_const(ctx, &x_is_equal, &F::zero()); - } + let (P, Q) = check_points_are_unequal(chip, ctx, P, Q, is_strict); let dx = chip.sub_no_carry(ctx, &Q.x, &P.x); - let dy = chip.sub_no_carry(ctx, &Q.y, &P.y); - let lambda = chip.divide(ctx, &dy, &dx); + let dy = chip.sub_no_carry(ctx, Q.y, &P.y); + let lambda = chip.divide_unsafe(ctx, dy, dx); // x_3 = lambda^2 - x_1 - x_2 (mod p) let lambda_sq = chip.mul_no_carry(ctx, &lambda, &lambda); - let lambda_sq_minus_px = chip.sub_no_carry(ctx, &lambda_sq, &P.x); - let x_3_no_carry = chip.sub_no_carry(ctx, &lambda_sq_minus_px, &Q.x); - let x_3 = chip.carry_mod(ctx, &x_3_no_carry); + let lambda_sq_minus_px = chip.sub_no_carry(ctx, lambda_sq, &P.x); + let x_3_no_carry = chip.sub_no_carry(ctx, lambda_sq_minus_px, Q.x); + let x_3 = chip.carry_mod(ctx, x_3_no_carry); // y_3 = lambda (x_1 - x_3) - y_1 mod p - let dx_13 = chip.sub_no_carry(ctx, &P.x, &x_3); - let lambda_dx_13 = chip.mul_no_carry(ctx, &lambda, &dx_13); - let y_3_no_carry = chip.sub_no_carry(ctx, &lambda_dx_13, &P.y); - let y_3 = chip.carry_mod(ctx, &y_3_no_carry); + let dx_13 = chip.sub_no_carry(ctx, P.x, &x_3); + let lambda_dx_13 = chip.mul_no_carry(ctx, lambda, dx_13); + let y_3_no_carry = chip.sub_no_carry(ctx, lambda_dx_13, P.y); + let y_3 = chip.carry_mod(ctx, y_3_no_carry); + + EcPoint::new(x_3, y_3) +} - EcPoint::construct(x_3, y_3) +/// If `do_check = true`, then this function constrains that `P.x != Q.x`. +/// Otherwise does nothing. +fn check_points_are_unequal>( + chip: &FC, + ctx: &mut Context, + P: impl Into>, + Q: impl Into>, + do_check: bool, +) -> (EcPoint /*P */, EcPoint /*Q */) { + let P = P.into(); + let Q = Q.into(); + if do_check { + // constrains that P.x != Q.x + let [x1, x2] = [&P, &Q].map(|pt| match pt { + ComparableEcPoint::Strict(pt) => pt.x.clone(), + ComparableEcPoint::NonStrict(pt) => chip.enforce_less_than(ctx, pt.x.clone()), + }); + let x_is_equal = chip.is_equal_unenforced(ctx, x1, x2); + chip.gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + } + (EcPoint::from(P), EcPoint::from(Q)) } // Implements: @@ -99,43 +208,80 @@ pub fn ec_add_unequal>( // y_3 = lambda (x_1 - x_3) - y_1 mod p // Assumes that P !=Q and Q != (P - Q) // -/// For optimization reasons, we assume that if you are using this with `is_strict = true`, then you have already called `chip.enforce_less_than_p` on both `P.x` and `Q.x` +/// If `is_strict = true`, then this function constrains that `P.x != Q.x`. +/// If you are calling this with `is_strict = false`, you must ensure that `P.x != Q.x` by some external logic (such +/// as a mathematical theorem). +/// +/// # Assumptions +/// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) pub fn ec_sub_unequal>( chip: &FC, ctx: &mut Context, - P: &EcPoint, - Q: &EcPoint, + P: impl Into>, + Q: impl Into>, is_strict: bool, ) -> EcPoint { - if is_strict { - // constrains that P.x != Q.x - let x_is_equal = chip.is_equal_unenforced(ctx, &P.x, &Q.x); - chip.range().gate().assert_is_const(ctx, &x_is_equal, &F::zero()); - } + let (P, Q) = check_points_are_unequal(chip, ctx, P, Q, is_strict); let dx = chip.sub_no_carry(ctx, &Q.x, &P.x); - let dy = chip.add_no_carry(ctx, &Q.y, &P.y); + let dy = chip.add_no_carry(ctx, Q.y, &P.y); - let lambda = chip.neg_divide(ctx, &dy, &dx); + let lambda = chip.neg_divide_unsafe(ctx, &dy, &dx); // (x_2 - x_1) * lambda + y_2 + y_1 = 0 (mod p) - let lambda_dx = chip.mul_no_carry(ctx, &lambda, &dx); - let lambda_dx_plus_dy = chip.add_no_carry(ctx, &lambda_dx, &dy); - chip.check_carry_mod_to_zero(ctx, &lambda_dx_plus_dy); + let lambda_dx = chip.mul_no_carry(ctx, &lambda, dx); + let lambda_dx_plus_dy = chip.add_no_carry(ctx, lambda_dx, dy); + chip.check_carry_mod_to_zero(ctx, lambda_dx_plus_dy); // x_3 = lambda^2 - x_1 - x_2 (mod p) let lambda_sq = chip.mul_no_carry(ctx, &lambda, &lambda); - let lambda_sq_minus_px = chip.sub_no_carry(ctx, &lambda_sq, &P.x); - let x_3_no_carry = chip.sub_no_carry(ctx, &lambda_sq_minus_px, &Q.x); - let x_3 = chip.carry_mod(ctx, &x_3_no_carry); + let lambda_sq_minus_px = chip.sub_no_carry(ctx, lambda_sq, &P.x); + let x_3_no_carry = chip.sub_no_carry(ctx, lambda_sq_minus_px, Q.x); + let x_3 = chip.carry_mod(ctx, x_3_no_carry); // y_3 = lambda (x_1 - x_3) - y_1 mod p - let dx_13 = chip.sub_no_carry(ctx, &P.x, &x_3); - let lambda_dx_13 = chip.mul_no_carry(ctx, &lambda, &dx_13); - let y_3_no_carry = chip.sub_no_carry(ctx, &lambda_dx_13, &P.y); - let y_3 = chip.carry_mod(ctx, &y_3_no_carry); + let dx_13 = chip.sub_no_carry(ctx, P.x, &x_3); + let lambda_dx_13 = chip.mul_no_carry(ctx, lambda, dx_13); + let y_3_no_carry = chip.sub_no_carry(ctx, lambda_dx_13, P.y); + let y_3 = chip.carry_mod(ctx, y_3_no_carry); - EcPoint::construct(x_3, y_3) + EcPoint::new(x_3, y_3) +} + +/// Constrains `P != -Q` but allows `P == Q`, in which case output is (0,0). +/// For Weierstrass curves only. +pub fn ec_sub_strict>( + chip: &FC, + ctx: &mut Context, + P: impl Into>, + Q: impl Into>, +) -> EcPoint +where + FC: Selectable, +{ + let mut P = P.into(); + let Q = Q.into(); + // Compute curr_point - start_point, allowing for output to be identity point + let x_is_eq = chip.is_equal(ctx, P.x(), Q.x()); + let y_is_eq = chip.is_equal(ctx, P.y(), Q.y()); + let is_identity = chip.gate().and(ctx, x_is_eq, y_is_eq); + // we ONLY allow x_is_eq = true if y_is_eq is also true; this constrains P != -Q + ctx.constrain_equal(&x_is_eq, &is_identity); + + // P.x = Q.x and P.y = Q.y + // in ec_sub_unequal it will try to do -(P.y + Q.y) / (P.x - Q.x) = -2P.y / 0 + // this will cause divide_unsafe to panic when P.y != 0 + // to avoid this, we load a random pair of points and replace P with it *only if* `is_identity == true` + // we don't even check (rand_x, rand_y) is on the curve, since we don't care about the output + let mut rng = ChaCha20Rng::from_entropy(); + let [rand_x, rand_y] = [(); 2].map(|_| FC::FieldType::random(&mut rng)); + let [rand_x, rand_y] = [rand_x, rand_y].map(|x| chip.load_private(ctx, x)); + let rand_pt = EcPoint::new(rand_x, rand_y); + P = ec_select(chip, ctx, rand_pt, P, is_identity); + + let out = ec_sub_unequal(chip, ctx, P, Q, false); + let zero = chip.load_constant(ctx, FC::FieldType::zero()); + ec_select(chip, ctx, EcPoint::new(zero.clone(), zero), out, is_identity) } // Implements: @@ -150,30 +296,34 @@ pub fn ec_sub_unequal>( // we precompute lambda and constrain (2y) * lambda = 3 x^2 (mod p) // then we compute x_3 = lambda^2 - 2 x (mod p) // y_3 = lambda (x - x_3) - y (mod p) +/// # Assumptions +/// * `P.y != 0` +/// * `P` is not the point at infinity (undefined behavior otherwise) pub fn ec_double>( chip: &FC, ctx: &mut Context, - P: &EcPoint, + P: impl Into>, ) -> EcPoint { + let P = P.into(); // removed optimization that computes `2 * lambda` while assigning witness to `lambda` simultaneously, in favor of readability. The difference is just copying `lambda` once let two_y = chip.scalar_mul_no_carry(ctx, &P.y, 2); let three_x = chip.scalar_mul_no_carry(ctx, &P.x, 3); - let three_x_sq = chip.mul_no_carry(ctx, &three_x, &P.x); - let lambda = chip.divide(ctx, &three_x_sq, &two_y); + let three_x_sq = chip.mul_no_carry(ctx, three_x, &P.x); + let lambda = chip.divide_unsafe(ctx, three_x_sq, two_y); // x_3 = lambda^2 - 2 x % p let lambda_sq = chip.mul_no_carry(ctx, &lambda, &lambda); let two_x = chip.scalar_mul_no_carry(ctx, &P.x, 2); - let x_3_no_carry = chip.sub_no_carry(ctx, &lambda_sq, &two_x); - let x_3 = chip.carry_mod(ctx, &x_3_no_carry); + let x_3_no_carry = chip.sub_no_carry(ctx, lambda_sq, two_x); + let x_3 = chip.carry_mod(ctx, x_3_no_carry); // y_3 = lambda (x - x_3) - y % p - let dx = chip.sub_no_carry(ctx, &P.x, &x_3); - let lambda_dx = chip.mul_no_carry(ctx, &lambda, &dx); - let y_3_no_carry = chip.sub_no_carry(ctx, &lambda_dx, &P.y); - let y_3 = chip.carry_mod(ctx, &y_3_no_carry); + let dx = chip.sub_no_carry(ctx, P.x, &x_3); + let lambda_dx = chip.mul_no_carry(ctx, lambda, dx); + let y_3_no_carry = chip.sub_no_carry(ctx, lambda_dx, P.y); + let y_3 = chip.carry_mod(ctx, y_3_no_carry); - EcPoint::construct(x_3, y_3) + EcPoint::new(x_3, y_3) } /// Implements: @@ -185,124 +335,168 @@ pub fn ec_double>( // lambda_1 = lambda_0 + 2 * y_0 / (x_2 - x_0) // x_res = lambda_1^2 - x_0 - x_2 // y_res = lambda_1 * (x_res - x_0) - y_0 +/// +/// # Assumptions +/// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) pub fn ec_double_and_add_unequal>( chip: &FC, ctx: &mut Context, - P: &EcPoint, - Q: &EcPoint, + P: impl Into>, + Q: impl Into>, is_strict: bool, ) -> EcPoint { + let P = P.into(); + let Q = Q.into(); + let mut x_0 = None; if is_strict { // constrains that P.x != Q.x - let x_is_equal = chip.is_equal_unenforced(ctx, &P.x, &Q.x); - chip.range().gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + let [x0, x1] = [&P, &Q].map(|pt| match pt { + ComparableEcPoint::Strict(pt) => pt.x.clone(), + ComparableEcPoint::NonStrict(pt) => chip.enforce_less_than(ctx, pt.x.clone()), + }); + let x_is_equal = chip.is_equal_unenforced(ctx, x0.clone(), x1); + chip.gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + x_0 = Some(x0); } + let P = EcPoint::from(P); + let Q = EcPoint::from(Q); let dx = chip.sub_no_carry(ctx, &Q.x, &P.x); - let dy = chip.sub_no_carry(ctx, &Q.y, &P.y); - let lambda_0 = chip.divide(ctx, &dy, &dx); + let dy = chip.sub_no_carry(ctx, Q.y, &P.y); + let lambda_0 = chip.divide_unsafe(ctx, dy, dx); // x_2 = lambda_0^2 - x_0 - x_1 (mod p) let lambda_0_sq = chip.mul_no_carry(ctx, &lambda_0, &lambda_0); - let lambda_0_sq_minus_x_0 = chip.sub_no_carry(ctx, &lambda_0_sq, &P.x); - let x_2_no_carry = chip.sub_no_carry(ctx, &lambda_0_sq_minus_x_0, &Q.x); - let x_2 = chip.carry_mod(ctx, &x_2_no_carry); + let lambda_0_sq_minus_x_0 = chip.sub_no_carry(ctx, lambda_0_sq, &P.x); + let x_2_no_carry = chip.sub_no_carry(ctx, lambda_0_sq_minus_x_0, Q.x); + let x_2 = chip.carry_mod(ctx, x_2_no_carry); if is_strict { + let x_2 = chip.enforce_less_than(ctx, x_2.clone()); // TODO: when can we remove this check? // constrains that x_2 != x_0 - let x_is_equal = chip.is_equal_unenforced(ctx, &P.x, &x_2); + let x_is_equal = chip.is_equal_unenforced(ctx, x_0.unwrap(), x_2); chip.range().gate().assert_is_const(ctx, &x_is_equal, &F::zero()); } // lambda_1 = lambda_0 + 2 * y_0 / (x_2 - x_0) let two_y_0 = chip.scalar_mul_no_carry(ctx, &P.y, 2); let x_2_minus_x_0 = chip.sub_no_carry(ctx, &x_2, &P.x); - let lambda_1_minus_lambda_0 = chip.divide(ctx, &two_y_0, &x_2_minus_x_0); - let lambda_1_no_carry = chip.add_no_carry(ctx, &lambda_0, &lambda_1_minus_lambda_0); + let lambda_1_minus_lambda_0 = chip.divide_unsafe(ctx, two_y_0, x_2_minus_x_0); + let lambda_1_no_carry = chip.add_no_carry(ctx, lambda_0, lambda_1_minus_lambda_0); // x_res = lambda_1^2 - x_0 - x_2 let lambda_1_sq_nc = chip.mul_no_carry(ctx, &lambda_1_no_carry, &lambda_1_no_carry); - let lambda_1_sq_minus_x_0 = chip.sub_no_carry(ctx, &lambda_1_sq_nc, &P.x); - let x_res_no_carry = chip.sub_no_carry(ctx, &lambda_1_sq_minus_x_0, &x_2); - let x_res = chip.carry_mod(ctx, &x_res_no_carry); + let lambda_1_sq_minus_x_0 = chip.sub_no_carry(ctx, lambda_1_sq_nc, &P.x); + let x_res_no_carry = chip.sub_no_carry(ctx, lambda_1_sq_minus_x_0, x_2); + let x_res = chip.carry_mod(ctx, x_res_no_carry); // y_res = lambda_1 * (x_res - x_0) - y_0 - let x_res_minus_x_0 = chip.sub_no_carry(ctx, &x_res, &P.x); - let lambda_1_x_res_minus_x_0 = chip.mul_no_carry(ctx, &lambda_1_no_carry, &x_res_minus_x_0); - let y_res_no_carry = chip.sub_no_carry(ctx, &lambda_1_x_res_minus_x_0, &P.y); - let y_res = chip.carry_mod(ctx, &y_res_no_carry); + let x_res_minus_x_0 = chip.sub_no_carry(ctx, &x_res, P.x); + let lambda_1_x_res_minus_x_0 = chip.mul_no_carry(ctx, lambda_1_no_carry, x_res_minus_x_0); + let y_res_no_carry = chip.sub_no_carry(ctx, lambda_1_x_res_minus_x_0, P.y); + let y_res = chip.carry_mod(ctx, y_res_no_carry); - EcPoint::construct(x_res, y_res) + EcPoint::new(x_res, y_res) } pub fn ec_select( chip: &FC, ctx: &mut Context, - P: &EcPoint, - Q: &EcPoint, + P: EcPoint, + Q: EcPoint, sel: AssignedValue, ) -> EcPoint where - FC: FieldChip + Selectable, + FC: FieldChip + Selectable, { - let Rx = chip.select(ctx, &P.x, &Q.x, sel); - let Ry = chip.select(ctx, &P.y, &Q.y, sel); - EcPoint::construct(Rx, Ry) + let Rx = chip.select(ctx, P.x, Q.x, sel); + let Ry = chip.select(ctx, P.y, Q.y, sel); + EcPoint::new(Rx, Ry) } // takes the dot product of points with sel, where each is intepreted as // a _vector_ -pub fn ec_select_by_indicator( +pub fn ec_select_by_indicator( chip: &FC, ctx: &mut Context, - points: &[EcPoint], + points: &[Pt], coeffs: &[AssignedValue], ) -> EcPoint where - FC: FieldChip + Selectable, + FC: FieldChip + Selectable, + Pt: Into> + Clone, { - let x_coords = points.iter().map(|P| P.x.clone()).collect::>(); - let y_coords = points.iter().map(|P| P.y.clone()).collect::>(); - let Rx = chip.select_by_indicator(ctx, &x_coords, coeffs); - let Ry = chip.select_by_indicator(ctx, &y_coords, coeffs); - EcPoint::construct(Rx, Ry) + let (x, y): (Vec<_>, Vec<_>) = points + .iter() + .map(|P| { + let P: EcPoint<_, _> = P.clone().into(); + (P.x, P.y) + }) + .unzip(); + let Rx = chip.select_by_indicator(ctx, &x, coeffs); + let Ry = chip.select_by_indicator(ctx, &y, coeffs); + EcPoint::new(Rx, Ry) } // `sel` is little-endian binary -pub fn ec_select_from_bits( +pub fn ec_select_from_bits( chip: &FC, ctx: &mut Context, - points: &[EcPoint], + points: &[Pt], sel: &[AssignedValue], ) -> EcPoint where - FC: FieldChip + Selectable, + FC: FieldChip + Selectable, + Pt: Into> + Clone, { let w = sel.len(); - let num_points = points.len(); - assert_eq!(1 << w, num_points); + assert_eq!(1 << w, points.len()); let coeffs = chip.range().gate().bits_to_indicator(ctx, sel); ec_select_by_indicator(chip, ctx, points, &coeffs) } -// computes [scalar] * P on y^2 = x^3 + b -// - `scalar` is represented as a reference array of `AssignedCell`s -// - `scalar = sum_i scalar_i * 2^{max_bits * i}` -// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` -// assumes: -// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits) -// - `max_bits <= modulus::.bits()` -// * P has order given by the scalar field modulus +// `sel` is little-endian binary +pub fn strict_ec_select_from_bits( + chip: &FC, + ctx: &mut Context, + points: &[StrictEcPoint], + sel: &[AssignedValue], +) -> StrictEcPoint +where + FC: FieldChip + Selectable + Selectable, +{ + let w = sel.len(); + assert_eq!(1 << w, points.len()); + let coeffs = chip.range().gate().bits_to_indicator(ctx, sel); + let (x, y): (Vec<_>, Vec<_>) = points.iter().map(|pt| (pt.x.clone(), pt.y.clone())).unzip(); + let x = chip.select_by_indicator(ctx, &x, &coeffs); + let y = chip.select_by_indicator(ctx, &y, &coeffs); + StrictEcPoint::new(x, y) +} + +/// Computes `[scalar] * P` on short Weierstrass curve `y^2 = x^3 + b` +/// - `scalar` is represented as a reference array of `AssignedValue`s +/// - `scalar = sum_i scalar_i * 2^{max_bits * i}` +/// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` +/// +/// # Assumptions +/// - `P` is not the point at infinity +/// - `scalar > 0` +/// - If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`) +/// - Even if `scalar_is_safe == false`, some constraints may still fail if `scalar` is not in range [1, order of `P`) +/// - `scalar_i < 2^{max_bits} for all i` +/// - `max_bits <= modulus::.bits()`, and equality only allowed when the order of `P` equals the modulus of `F` pub fn scalar_multiply( chip: &FC, ctx: &mut Context, - P: &EcPoint, + P: EcPoint, scalar: Vec>, max_bits: usize, window_bits: usize, + scalar_is_safe: bool, ) -> EcPoint where - FC: FieldChip + Selectable, + FC: FieldChip + Selectable, { assert!(!scalar.is_empty()); assert!((max_bits as u64) <= modulus::().bits()); @@ -347,16 +541,16 @@ where cached_points.push(P.clone()); for idx in 2..cache_size { if idx == 2 { - let double = ec_double(chip, ctx, P /*, b*/); + let double = ec_double(chip, ctx, &P); cached_points.push(double); } else { - let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], P, false); + let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, !scalar_is_safe); cached_points.push(new_point); } } // if all the starting window bits are 0, get start_point = P - let mut curr_point = ec_select_from_bits::( + let mut curr_point = ec_select_from_bits( chip, ctx, &cached_points, @@ -366,39 +560,38 @@ where for idx in 1..num_windows { let mut mult_point = curr_point.clone(); for _ in 0..window_bits { - mult_point = ec_double(chip, ctx, &mult_point); + mult_point = ec_double(chip, ctx, mult_point); } - let add_point = ec_select_from_bits::( + let add_point = ec_select_from_bits( chip, ctx, &cached_points, &rounded_bits [rounded_bitlen - window_bits * (idx + 1)..rounded_bitlen - window_bits * idx], ); - let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, false); - let is_started_point = - ec_select(chip, ctx, &mult_point, &mult_and_add, is_zero_window[idx]); + let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, !scalar_is_safe); + let is_started_point = ec_select(chip, ctx, mult_point, mult_and_add, is_zero_window[idx]); curr_point = - ec_select(chip, ctx, &is_started_point, &add_point, is_started[window_bits * idx]); + ec_select(chip, ctx, is_started_point, add_point, is_started[window_bits * idx]); } curr_point } -pub fn is_on_curve(chip: &FC, ctx: &mut Context, P: &EcPoint) +/// Checks that `P` is indeed a point on the elliptic curve `C`. +pub fn check_is_on_curve(chip: &FC, ctx: &mut Context, P: &EcPoint) where F: PrimeField, FC: FieldChip, C: CurveAffine, { let lhs = chip.mul_no_carry(ctx, &P.y, &P.y); - let mut rhs = chip.mul(ctx, &P.x, &P.x); - rhs = chip.mul_no_carry(ctx, &rhs, &P.x); + let mut rhs = chip.mul(ctx, &P.x, &P.x).into(); + rhs = chip.mul_no_carry(ctx, rhs, &P.x); - let b = FC::fe_to_constant(C::b()); - rhs = chip.add_constant_no_carry(ctx, &rhs, b); - let diff = chip.sub_no_carry(ctx, &lhs, &rhs); - chip.check_carry_mod_to_zero(ctx, &diff) + rhs = chip.add_constant_no_carry(ctx, rhs, C::b()); + let diff = chip.sub_no_carry(ctx, lhs, rhs); + chip.check_carry_mod_to_zero(ctx, diff) } pub fn load_random_point(chip: &FC, ctx: &mut Context) -> EcPoint @@ -409,24 +602,45 @@ where { let base_point: C = C::CurveExt::random(ChaCha20Rng::from_entropy()).to_affine(); let (x, y) = base_point.into_coordinates(); - let pt_x = FC::fe_to_witness(&x); - let pt_y = FC::fe_to_witness(&y); let base = { - let x_overflow = chip.load_private(ctx, pt_x); - let y_overflow = chip.load_private(ctx, pt_y); - EcPoint::construct(x_overflow, y_overflow) + let x_overflow = chip.load_private(ctx, x); + let y_overflow = chip.load_private(ctx, y); + EcPoint::new(x_overflow, y_overflow) }; // for above reason we still need to constrain that the witness is on the curve - is_on_curve::(chip, ctx, &base); + check_is_on_curve::(chip, ctx, &base); base } +pub fn into_strict_point( + chip: &FC, + ctx: &mut Context, + pt: EcPoint, +) -> StrictEcPoint +where + F: PrimeField, + FC: FieldChip, +{ + let x = chip.enforce_less_than(ctx, pt.x); + StrictEcPoint::new(x, pt.y) +} + // need to supply an extra generic `C` implementing `CurveAffine` trait in order to generate random witness points on the curve in question // Using Simultaneous 2^w-Ary Method, see https://www.bmoeller.de/pdf/multiexp-sac2001.pdf // Random Accumlation point trick learned from halo2wrong: https://hackmd.io/ncuKqRXzR-Cw-Au2fGzsMg?view // Input: // - `scalars` is vector of same length as `P` // - each `scalar` in `scalars` satisfies same assumptions as in `scalar_multiply` above + +/// # Assumptions +/// * `points.len() == scalars.len()` +/// * `scalars[i].len() == scalars[j].len()` for all `i, j` +/// * `scalars[i]` is less than the order of `P` +/// * `scalars[i][j] < 2^{max_bits} for all j` +/// * `max_bits <= modulus::.bits()`, and equality only allowed when the order of `P` equals the modulus of `F` +/// * `points` are all on the curve or the point at infinity +/// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point) +/// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point pub fn multi_scalar_multiply( chip: &FC, ctx: &mut Context, @@ -436,7 +650,7 @@ pub fn multi_scalar_multiply( window_bits: usize, ) -> EcPoint where - FC: FieldChip + Selectable, + FC: FieldChip + Selectable, C: CurveAffineExt, { let k = P.len(); @@ -463,7 +677,7 @@ where }) .collect_vec(); - // load random C point as witness + // load any sufficiently generic C point as witness // note that while we load a random point, an adversary would load a specifically chosen point, so we must carefully handle edge cases with constraints let base = load_random_point::(chip, ctx); // contains random base points [A, ..., 2^{w + k - 1} * A] @@ -488,19 +702,19 @@ where ctx, &rand_start_vec[idx], &rand_start_vec[idx + window_bits], - false, + true, // not necessary if we assume (2^w - 1) * A != +- A, but put in for safety ); - chip.enforce_less_than(ctx, point.x()); - chip.enforce_less_than(ctx, neg_mult_rand_start.x()); + let point = into_strict_point(chip, ctx, point.clone()); + let neg_mult_rand_start = into_strict_point(chip, ctx, neg_mult_rand_start); // cached_points[i][0..cache_size] stores (1 - 2^w) * 2^i * A + [0..cache_size] * P_i cached_points.push(neg_mult_rand_start); for _ in 0..(cache_size - 1) { - let prev = cached_points.last().unwrap(); + let prev = cached_points.last().unwrap().clone(); // adversary could pick `A` so add equal case occurs, so we must use strict add_unequal - let mut new_point = ec_add_unequal(chip, ctx, prev, point, true); + let mut new_point = ec_add_unequal(chip, ctx, &prev, &point, true); // special case for when P[idx] = O - new_point = ec_select(chip, ctx, prev, &new_point, is_infinity); - chip.enforce_less_than(ctx, new_point.x()); + new_point = ec_select(chip, ctx, prev.into(), new_point, is_infinity); + let new_point = into_strict_point(chip, ctx, new_point); cached_points.push(new_point); } } @@ -509,38 +723,35 @@ where // note k can be large (e.g., 800) so 2^{k+1} may be larger than the order of A // random fact: 2^{k + 1} - 1 can be prime: see Mersenne primes // TODO: I don't see a way to rule out 2^{k+1} A = +-A case in general, so will use strict sub_unequal - let start_point = if k < F::CAPACITY as usize { - ec_sub_unequal(chip, ctx, &rand_start_vec[k], &rand_start_vec[0], false) - } else { - chip.enforce_less_than(ctx, rand_start_vec[k].x()); - chip.enforce_less_than(ctx, rand_start_vec[0].x()); - ec_sub_unequal(chip, ctx, &rand_start_vec[k], &rand_start_vec[0], true) - }; + let start_point = ec_sub_unequal( + chip, + ctx, + &rand_start_vec[k], + &rand_start_vec[0], + k >= F::CAPACITY as usize, + ); let mut curr_point = start_point.clone(); // compute \sum_i x_i P_i + (2^{k + 1} - 1) * A for idx in 0..num_windows { for _ in 0..window_bits { - curr_point = ec_double(chip, ctx, &curr_point); + curr_point = ec_double(chip, ctx, curr_point); } for (cached_points, rounded_bits) in cached_points.chunks(cache_size).zip(rounded_bits.chunks(rounded_bitlen)) { - let add_point = ec_select_from_bits::( + let add_point = ec_select_from_bits( chip, ctx, cached_points, &rounded_bits [rounded_bitlen - window_bits * (idx + 1)..rounded_bitlen - window_bits * idx], ); - chip.enforce_less_than(ctx, curr_point.x()); // this all needs strict add_unequal since A can be non-randomly chosen by adversary - curr_point = ec_add_unequal(chip, ctx, &curr_point, &add_point, true); + curr_point = ec_add_unequal(chip, ctx, curr_point, add_point, true); } } - chip.enforce_less_than(ctx, start_point.x()); - chip.enforce_less_than(ctx, curr_point.x()); - ec_sub_unequal(chip, ctx, &curr_point, &start_point, true) + ec_sub_strict(chip, ctx, curr_point, start_point) } pub fn get_naf(mut exp: Vec) -> Vec { @@ -608,26 +819,56 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { self.field_chip } - pub fn load_private( + /// Load affine point as private witness. Constrains witness to lie on curve. Does not allow (0, 0) point, + pub fn load_private( &self, ctx: &mut Context, - point: (FC::FieldType, FC::FieldType), - ) -> EcPoint { - let (x, y) = (FC::fe_to_witness(&point.0), FC::fe_to_witness(&point.1)); + (x, y): (FC::FieldType, FC::FieldType), + ) -> EcPoint + where + C: CurveAffineExt, + { + let pt = self.load_private_unchecked(ctx, (x, y)); + self.assert_is_on_curve::(ctx, &pt); + pt + } + /// Does not constrain witness to lie on curve + pub fn load_private_unchecked( + &self, + ctx: &mut Context, + (x, y): (FC::FieldType, FC::FieldType), + ) -> EcPoint { let x_assigned = self.field_chip.load_private(ctx, x); let y_assigned = self.field_chip.load_private(ctx, y); - EcPoint::construct(x_assigned, y_assigned) + EcPoint::new(x_assigned, y_assigned) } - /// Does not constrain witness to lie on curve + /// Load affine point as private witness. Constrains witness to either lie on curve or be the point at infinity, + /// represented in affine coordinates as (0, 0). pub fn assign_point(&self, ctx: &mut Context, g: C) -> EcPoint + where + C: CurveAffineExt, + C::Base: ff::PrimeField, + { + let pt = self.assign_point_unchecked(ctx, g); + let is_on_curve = self.is_on_curve_or_infinity::(ctx, &pt); + self.field_chip.gate().assert_is_const(ctx, &is_on_curve, &F::one()); + pt + } + + /// Does not constrain witness to lie on curve + pub fn assign_point_unchecked( + &self, + ctx: &mut Context, + g: C, + ) -> EcPoint where C: CurveAffineExt, { let (x, y) = g.into_coordinates(); - self.load_private(ctx, (x, y)) + self.load_private_unchecked(ctx, (x, y)) } pub fn assign_constant_point(&self, ctx: &mut Context, g: C) -> EcPoint @@ -635,11 +876,10 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { C: CurveAffineExt, { let (x, y) = g.into_coordinates(); - let [x, y] = [x, y].map(FC::fe_to_constant); let x = self.field_chip.load_constant(ctx, x); let y = self.field_chip.load_constant(ctx, y); - EcPoint::construct(x, y) + EcPoint::new(x, y) } pub fn load_random_point(&self, ctx: &mut Context) -> EcPoint @@ -653,7 +893,7 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { where C: CurveAffine, { - is_on_curve::(self.field_chip, ctx, P) + check_is_on_curve::(self.field_chip, ctx, P) } pub fn is_on_curve_or_infinity( @@ -663,18 +903,16 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { ) -> AssignedValue where C: CurveAffine, - C::Base: ff::PrimeField, { let lhs = self.field_chip.mul_no_carry(ctx, &P.y, &P.y); - let mut rhs = self.field_chip.mul(ctx, &P.x, &P.x); - rhs = self.field_chip.mul_no_carry(ctx, &rhs, &P.x); + let mut rhs = self.field_chip.mul(ctx, &P.x, &P.x).into(); + rhs = self.field_chip.mul_no_carry(ctx, rhs, &P.x); - let b = FC::fe_to_constant(C::b()); - rhs = self.field_chip.add_constant_no_carry(ctx, &rhs, b); - let mut diff = self.field_chip.sub_no_carry(ctx, &lhs, &rhs); - diff = self.field_chip.carry_mod(ctx, &diff); + rhs = self.field_chip.add_constant_no_carry(ctx, rhs, C::b()); + let diff = self.field_chip.sub_no_carry(ctx, lhs, rhs); + let diff = self.field_chip.carry_mod(ctx, diff); - let is_on_curve = self.field_chip.is_zero(ctx, &diff); + let is_on_curve = self.field_chip.is_zero(ctx, diff); let x_is_zero = self.field_chip.is_zero(ctx, &P.x); let y_is_zero = self.field_chip.is_zero(ctx, &P.y); @@ -685,9 +923,10 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { pub fn negate( &self, ctx: &mut Context, - P: &EcPoint, + P: impl Into>, ) -> EcPoint { - EcPoint::construct(P.x.clone(), self.field_chip.negate(ctx, &P.y)) + let P = P.into(); + EcPoint::new(P.x, self.field_chip.negate(ctx, P.y)) } /// Assumes that P.x != Q.x @@ -695,8 +934,8 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { pub fn add_unequal( &self, ctx: &mut Context, - P: &EcPoint, - Q: &EcPoint, + P: impl Into>, + Q: impl Into>, is_strict: bool, ) -> EcPoint { ec_add_unequal(self.field_chip, ctx, P, Q, is_strict) @@ -707,8 +946,8 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { pub fn sub_unequal( &self, ctx: &mut Context, - P: &EcPoint, - Q: &EcPoint, + P: impl Into>, + Q: impl Into>, is_strict: bool, ) -> EcPoint { ec_sub_unequal(self.field_chip, ctx, P, Q, is_strict) @@ -717,7 +956,7 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { pub fn double( &self, ctx: &mut Context, - P: &EcPoint, + P: impl Into>, ) -> EcPoint { ec_double(self.field_chip, ctx, P) } @@ -725,72 +964,82 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { pub fn is_equal( &self, ctx: &mut Context, - P: &EcPoint, - Q: &EcPoint, + P: EcPoint, + Q: EcPoint, ) -> AssignedValue { // TODO: optimize - let x_is_equal = self.field_chip.is_equal(ctx, &P.x, &Q.x); - let y_is_equal = self.field_chip.is_equal(ctx, &P.y, &Q.y); + let x_is_equal = self.field_chip.is_equal(ctx, P.x, Q.x); + let y_is_equal = self.field_chip.is_equal(ctx, P.y, Q.y); self.field_chip.range().gate().and(ctx, x_is_equal, y_is_equal) } pub fn assert_equal( &self, ctx: &mut Context, - P: &EcPoint, - Q: &EcPoint, + P: EcPoint, + Q: EcPoint, ) { - self.field_chip.assert_equal(ctx, &P.x, &Q.x); - self.field_chip.assert_equal(ctx, &P.y, &Q.y); + self.field_chip.assert_equal(ctx, P.x, Q.x); + self.field_chip.assert_equal(ctx, P.y, Q.y); } - pub fn sum<'b, 'v: 'b, C>( + /// None of elements in `points` can be point at infinity. + pub fn sum( &self, ctx: &mut Context, - points: impl Iterator>, + points: impl IntoIterator>, ) -> EcPoint where C: CurveAffineExt, - FC::FieldPoint: 'b, { let rand_point = self.load_random_point::(ctx); - self.field_chip.enforce_less_than(ctx, rand_point.x()); + let rand_point = into_strict_point(self.field_chip, ctx, rand_point); let mut acc = rand_point.clone(); for point in points { - self.field_chip.enforce_less_than(ctx, point.x()); - acc = self.add_unequal(ctx, &acc, point, true); - self.field_chip.enforce_less_than(ctx, acc.x()); + let _acc = self.add_unequal(ctx, acc, point, true); + acc = into_strict_point(self.field_chip, ctx, _acc); } - self.sub_unequal(ctx, &acc, &rand_point, true) + self.sub_unequal(ctx, acc, rand_point, true) } } impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> where - FC: Selectable, + FC: Selectable, { pub fn select( &self, ctx: &mut Context, - P: &EcPoint, - Q: &EcPoint, + P: EcPoint, + Q: EcPoint, condition: AssignedValue, ) -> EcPoint { ec_select(self.field_chip, ctx, P, Q, condition) } + /// See [`scalar_multiply`] for more details. pub fn scalar_mult( &self, ctx: &mut Context, - P: &EcPoint, + P: EcPoint, scalar: Vec>, max_bits: usize, window_bits: usize, + scalar_is_safe: bool, ) -> EcPoint { - scalar_multiply::(self.field_chip, ctx, P, scalar, max_bits, window_bits) + scalar_multiply::( + self.field_chip, + ctx, + P, + scalar, + max_bits, + window_bits, + scalar_is_safe, + ) } // default for most purposes + /// See [`pippenger::multi_exp_par`] for more details. pub fn variable_base_msm( &self, thread_pool: &mut GateThreadBuilder, @@ -801,6 +1050,7 @@ where where C: CurveAffineExt, C::Base: ff::PrimeField, + FC: Selectable, { // window_bits = 4 is optimal from empirical observations self.variable_base_msm_in::(thread_pool, P, scalars, max_bits, 4, 0) @@ -819,6 +1069,7 @@ where where C: CurveAffineExt, C::Base: ff::PrimeField, + FC: Selectable, { #[cfg(feature = "display")] println!("computing length {} MSM", P.len()); @@ -854,10 +1105,8 @@ where } } -impl<'chip, F: PrimeField, FC: PrimeFieldChip> EccChip<'chip, F, FC> -where - FC::FieldType: PrimeField, -{ +impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { + /// See [`fixed_base::scalar_multiply`] for more details. // TODO: put a check in place that scalar is < modulus of C::Scalar pub fn fixed_base_scalar_mult( &self, @@ -866,11 +1115,11 @@ where scalar: Vec>, max_bits: usize, window_bits: usize, + scalar_is_safe: bool, ) -> EcPoint where C: CurveAffineExt, - FC: PrimeFieldChip> - + Selectable, + FC: FieldChip + Selectable, { fixed_base::scalar_multiply::( self.field_chip, @@ -879,6 +1128,7 @@ where scalar, max_bits, window_bits, + scalar_is_safe, ) } @@ -892,8 +1142,7 @@ where ) -> EcPoint where C: CurveAffineExt, - FC: PrimeFieldChip> - + Selectable, + FC: FieldChip + Selectable, { self.fixed_base_msm_in::(builder, points, scalars, max_scalar_bits_per_cell, 4, 0) } @@ -914,28 +1163,21 @@ where ) -> EcPoint where C: CurveAffineExt, - FC: PrimeFieldChip> - + Selectable, + FC: FieldChip + Selectable, { - debug_assert_eq!(points.len(), scalars.len()); + assert_eq!(points.len(), scalars.len()); #[cfg(feature = "display")] println!("computing length {} fixed base msm", points.len()); - // heuristic to decide when to use parallelism - if points.len() < 25 { - let ctx = builder.main(phase); - fixed_base::msm(self, ctx, points, scalars, max_scalar_bits_per_cell, clump_factor) - } else { - fixed_base::msm_par( - self, - builder, - points, - scalars, - max_scalar_bits_per_cell, - clump_factor, - phase, - ) - } + fixed_base::msm_par( + self, + builder, + points, + scalars, + max_scalar_bits_per_cell, + clump_factor, + phase, + ) // Empirically does not seem like pippenger is any better for fixed base msm right now, because of the cost of `select_by_indicator` // Cell usage becomes around comparable when `points.len() > 100`, and `clump_factor` should always be 4 diff --git a/halo2-ecc/src/ecc/pippenger.rs b/halo2-ecc/src/ecc/pippenger.rs index 58082c37..934a7432 100644 --- a/halo2-ecc/src/ecc/pippenger.rs +++ b/halo2-ecc/src/ecc/pippenger.rs @@ -1,14 +1,19 @@ use super::{ - ec_add_unequal, ec_double, ec_select, ec_select_from_bits, ec_sub_unequal, load_random_point, - EcPoint, + ec_add_unequal, ec_double, ec_select, ec_sub_unequal, into_strict_point, load_random_point, + strict_ec_select_from_bits, EcPoint, +}; +use crate::{ + ecc::ec_sub_strict, + fields::{FieldChip, PrimeField, Selectable}, }; -use crate::fields::{FieldChip, PrimeField, Selectable}; use halo2_base::{ - gates::{builder::GateThreadBuilder, GateInstructions}, + gates::{ + builder::{parallelize_in, GateThreadBuilder}, + GateInstructions, + }, utils::CurveAffineExt, - AssignedValue, Context, + AssignedValue, }; -use rayon::prelude::*; // Reference: https://jbootle.github.io/Misc/pippenger.pdf @@ -64,6 +69,7 @@ where } */ +/* Left as reference; should always use msm_par // Given points[i] and bool_scalars[j][i], // compute G'[j] = sum_{i=0..points.len()} points[i] * bool_scalars[j][i] // output is [ G'[j] + rand_point ]_{j=0..bool_scalars.len()}, rand_point @@ -73,9 +79,9 @@ pub fn multi_product( points: &[EcPoint], bool_scalars: &[Vec>], clumping_factor: usize, -) -> (Vec>, EcPoint) +) -> (Vec>, EcPoint) where - FC: FieldChip + Selectable, + FC: FieldChip + Selectable + Selectable, C: CurveAffineExt, { let c = clumping_factor; // this is `b` in Section 3 of Bootle @@ -84,66 +90,71 @@ where // we use a trick from halo2wrong where we load a random C point as witness // note that while we load a random point, an adversary could load a specifically chosen point, so we must carefully handle edge cases with constraints // TODO: an alternate approach is to use Fiat-Shamir transform (with Poseidon) to hash all the inputs (points, bool_scalars, ...) to get the random point. This could be worth it for large MSMs as we get savings from `add_unequal` in "non-strict" mode. Perhaps not worth the trouble / security concern, though. - let rand_base = load_random_point::(chip, ctx); + let any_base = load_random_point::(chip, ctx); let mut acc = Vec::with_capacity(bool_scalars.len()); let mut bucket = Vec::with_capacity(1 << c); - let mut rand_point = rand_base.clone(); + let mut any_point = any_base.clone(); for (round, points_clump) in points.chunks(c).enumerate() { // compute all possible multi-products of elements in points[round * c .. round * (c+1)] // for later addition collision-prevension, we need a different random point per round // we take 2^round * rand_base if round > 0 { - rand_point = ec_double(chip, ctx, &rand_point); + any_point = ec_double(chip, ctx, any_point); } // stores { rand_point, rand_point + points[0], rand_point + points[1], rand_point + points[0] + points[1] , ... } // since rand_point is random, we can always use add_unequal (with strict constraint checking that the points are indeed unequal and not negative of each other) bucket.clear(); - chip.enforce_less_than(ctx, rand_point.x()); - bucket.push(rand_point.clone()); + let strict_any_point = into_strict_point(chip, ctx, any_point.clone()); + bucket.push(strict_any_point); for (i, point) in points_clump.iter().enumerate() { // we allow for points[i] to be the point at infinity, represented by (0, 0) in affine coordinates // this can be checked by points[i].y == 0 iff points[i] == O let is_infinity = chip.is_zero(ctx, &point.y); - chip.enforce_less_than(ctx, point.x()); + let point = into_strict_point(chip, ctx, point.clone()); for j in 0..(1 << i) { - let mut new_point = ec_add_unequal(chip, ctx, &bucket[j], point, true); + let mut new_point = ec_add_unequal(chip, ctx, &bucket[j], &point, true); // if points[i] is point at infinity, do nothing - new_point = ec_select(chip, ctx, &bucket[j], &new_point, is_infinity); - chip.enforce_less_than(ctx, new_point.x()); + new_point = ec_select(chip, ctx, (&bucket[j]).into(), new_point, is_infinity); + let new_point = into_strict_point(chip, ctx, new_point); bucket.push(new_point); } } // for each j, select using clump in e[j][i=...] for (j, bits) in bool_scalars.iter().enumerate() { - let multi_prod = ec_select_from_bits::( + let multi_prod = strict_ec_select_from_bits( chip, ctx, &bucket, &bits[round * c..round * c + points_clump.len()], ); + // since `bucket` is all `StrictEcPoint` and we are selecting from it, we know `multi_prod` is StrictEcPoint // everything in bucket has already been enforced if round == 0 { acc.push(multi_prod); } else { - acc[j] = ec_add_unequal(chip, ctx, &acc[j], &multi_prod, true); - chip.enforce_less_than(ctx, acc[j].x()); + let _acc = ec_add_unequal(chip, ctx, &acc[j], multi_prod, true); + acc[j] = into_strict_point(chip, ctx, _acc); } } } // we have acc[j] = G'[j] + (2^num_rounds - 1) * rand_base - rand_point = ec_double(chip, ctx, &rand_point); - rand_point = ec_sub_unequal(chip, ctx, &rand_point, &rand_base, false); + any_point = ec_double(chip, ctx, any_point); + any_point = ec_sub_unequal(chip, ctx, any_point, any_base, false); - (acc, rand_point) + (acc, any_point) } -/// Currently does not support if the final answer is actually the point at infinity +/// Currently does not support if the final answer is actually the point at infinity (meaning constraints will fail in that case) +/// +/// # Assumptions +/// * `points.len() == scalars.len()` +/// * `scalars[i].len() == scalars[j].len()` for all `i, j` pub fn multi_exp( chip: &FC, ctx: &mut Context, @@ -154,7 +165,7 @@ pub fn multi_exp( clump_factor: usize, ) -> EcPoint where - FC: FieldChip + Selectable, + FC: FieldChip + Selectable + Selectable, C: CurveAffineExt, { // let (points, bool_scalars) = decompose::(chip, ctx, points, scalars, max_scalar_bits_per_cell, radix); @@ -174,35 +185,37 @@ where } } - let (mut agg, rand_point) = + let (mut agg, any_point) = multi_product::(chip, ctx, points, &bool_scalars, clump_factor); // everything in agg has been enforced // compute sum_{k=0..t} agg[k] * 2^{radix * k} - (sum_k 2^{radix * k}) * rand_point // (sum_{k=0..t} 2^{radix * k}) = (2^{radix * t} - 1)/(2^radix - 1) - let mut sum = agg.pop().unwrap(); - let mut rand_sum = rand_point.clone(); + let mut sum = agg.pop().unwrap().into(); + let mut any_sum = any_point.clone(); for g in agg.iter().rev() { - rand_sum = ec_double(chip, ctx, &rand_sum); + any_sum = ec_double(chip, ctx, any_sum); // cannot use ec_double_and_add_unequal because you cannot guarantee that `sum != g` - sum = ec_double(chip, ctx, &sum); - chip.enforce_less_than(ctx, sum.x()); - sum = ec_add_unequal(chip, ctx, &sum, g, true); + sum = ec_double(chip, ctx, sum); + sum = ec_add_unequal(chip, ctx, sum, g, true); } - rand_sum = ec_double(chip, ctx, &rand_sum); + any_sum = ec_double(chip, ctx, any_sum); // assume 2^scalar_bits != +-1 mod modulus::() - rand_sum = ec_sub_unequal(chip, ctx, &rand_sum, &rand_point, false); + any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, false); - chip.enforce_less_than(ctx, sum.x()); - chip.enforce_less_than(ctx, rand_sum.x()); - ec_sub_unequal(chip, ctx, &sum, &rand_sum, true) + ec_sub_unequal(chip, ctx, sum, any_sum, true) } +*/ /// Multi-thread witness generation for multi-scalar multiplication. -/// Should give exact same circuit as `multi_exp`. /// -/// Currently does not support if the final answer is actually the point at infinity +/// # Assumptions +/// * `points.len() == scalars.len()` +/// * `scalars[i].len() == scalars[j].len()` for all `i, j` +/// * `points` are all on the curve or the point at infinity +/// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point) +/// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point pub fn multi_exp_par( chip: &FC, // these are the "threads" within a single Phase @@ -215,19 +228,18 @@ pub fn multi_exp_par( phase: usize, ) -> EcPoint where - FC: FieldChip + Selectable, + FC: FieldChip + Selectable + Selectable, C: CurveAffineExt, { // let (points, bool_scalars) = decompose::(chip, ctx, points, scalars, max_scalar_bits_per_cell, radix); - debug_assert_eq!(points.len(), scalars.len()); + assert_eq!(points.len(), scalars.len()); let scalar_bits = max_scalar_bits_per_cell * scalars[0].len(); // bool_scalars: 2d array `scalar_bits` by `points.len()` let mut bool_scalars = vec![Vec::with_capacity(points.len()); scalar_bits]; // get a main thread let ctx = builder.main(phase); - let witness_gen_only = ctx.witness_gen_only(); // single-threaded computation: for scalar in scalars { for (scalar_chunk, bool_chunk) in @@ -239,117 +251,91 @@ where } } } - // see multi-product comments for explanation of below let c = clump_factor; let num_rounds = (points.len() + c - 1) / c; - let rand_base = load_random_point::(chip, ctx); - let mut rand_points = Vec::with_capacity(num_rounds); - rand_points.push(rand_base); + // to avoid adding two points that are equal or negative of each other, + // we use a trick from halo2wrong where we load a "sufficiently generic" `C` point as witness + // note that while we load a random point, an adversary could load a specifically chosen point, so we must carefully handle edge cases with constraints + // we call it "any point" instead of "random point" to emphasize that "any" sufficiently generic point will do + let any_base = load_random_point::(chip, ctx); + let mut any_points = Vec::with_capacity(num_rounds); + any_points.push(any_base); for _ in 1..num_rounds { - rand_points.push(ec_double(chip, ctx, rand_points.last().unwrap())); + any_points.push(ec_double(chip, ctx, any_points.last().unwrap())); } - // we will use a different thread per round - // to prevent concurrency issues with context id, we generate all the ids first - let thread_ids = (0..num_rounds).map(|_| builder.get_new_thread_id()).collect::>(); - // now begins multi-threading + // now begins multi-threading // multi_prods is 2d vector of size `num_rounds` by `scalar_bits` - let (new_threads, multi_prods): (Vec<_>, Vec<_>) = points - .par_chunks(c) - .zip(rand_points.par_iter()) - .zip(thread_ids.into_par_iter()) - .enumerate() - .map(|(round, ((points_clump, rand_point), thread_id))| { + let multi_prods = parallelize_in( + phase, + builder, + points.chunks(c).into_iter().zip(any_points.iter()).enumerate().collect(), + |ctx, (round, (points_clump, any_point))| { // compute all possible multi-products of elements in points[round * c .. round * (c+1)] - // create new thread - let mut thread = Context::new(witness_gen_only, thread_id); - let ctx = &mut thread; - // stores { rand_point, rand_point + points[0], rand_point + points[1], rand_point + points[0] + points[1] , ... } + // stores { any_point, any_point + points[0], any_point + points[1], any_point + points[0] + points[1] , ... } let mut bucket = Vec::with_capacity(1 << c); - chip.enforce_less_than(ctx, rand_point.x()); - bucket.push(rand_point.clone()); + let any_point = into_strict_point(chip, ctx, any_point.clone()); + bucket.push(any_point); for (i, point) in points_clump.iter().enumerate() { // we allow for points[i] to be the point at infinity, represented by (0, 0) in affine coordinates // this can be checked by points[i].y == 0 iff points[i] == O let is_infinity = chip.is_zero(ctx, &point.y); - chip.enforce_less_than(ctx, point.x()); + let point = into_strict_point(chip, ctx, point.clone()); for j in 0..(1 << i) { - let mut new_point = ec_add_unequal(chip, ctx, &bucket[j], point, true); + let mut new_point = ec_add_unequal(chip, ctx, &bucket[j], &point, true); // if points[i] is point at infinity, do nothing - new_point = ec_select(chip, ctx, &bucket[j], &new_point, is_infinity); - chip.enforce_less_than(ctx, new_point.x()); + new_point = ec_select(chip, ctx, (&bucket[j]).into(), new_point, is_infinity); + let new_point = into_strict_point(chip, ctx, new_point); bucket.push(new_point); } } - let multi_prods = bool_scalars + bool_scalars .iter() .map(|bits| { - ec_select_from_bits::( + strict_ec_select_from_bits( chip, ctx, &bucket, &bits[round * c..round * c + points_clump.len()], ) }) - .collect::>(); - - (thread, multi_prods) - }) - .unzip(); - // we collect the new threads to ensure they are a FIXED order, otherwise later `assign_threads_in` will get confused - builder.threads[phase].extend(new_threads); + .collect::>() + }, + ); // agg[j] = sum_{i=0..num_rounds} multi_prods[i][j] for j = 0..scalar_bits - let thread_ids = (0..scalar_bits).map(|_| builder.get_new_thread_id()).collect::>(); - let (new_threads, mut agg): (Vec<_>, Vec<_>) = thread_ids - .into_par_iter() - .enumerate() - .map(|(i, thread_id)| { - let mut thread = Context::new(witness_gen_only, thread_id); - let ctx = &mut thread; - let mut acc = if multi_prods.len() == 1 { - multi_prods[0][i].clone() - } else { - ec_add_unequal(chip, ctx, &multi_prods[0][i], &multi_prods[1][i], true) - }; - chip.enforce_less_than(ctx, acc.x()); - for multi_prod in multi_prods.iter().skip(2) { - acc = ec_add_unequal(chip, ctx, &acc, &multi_prod[i], true); - chip.enforce_less_than(ctx, acc.x()); - } - (thread, acc) - }) - .unzip(); - builder.threads[phase].extend(new_threads); + let mut agg = parallelize_in(phase, builder, (0..scalar_bits).collect(), |ctx, i| { + let mut acc = multi_prods[0][i].clone(); + for multi_prod in multi_prods.iter().skip(1) { + let _acc = ec_add_unequal(chip, ctx, &acc, &multi_prod[i], true); + acc = into_strict_point(chip, ctx, _acc); + } + acc + }); // gets the LAST thread for single threaded work - // warning: don't get any earlier threads, because currently we assume equality constraints in thread i only involves threads <= i let ctx = builder.main(phase); - // we have agg[j] = G'[j] + (2^num_rounds - 1) * rand_base - // let rand_point = (2^num_rounds - 1) * rand_base + // we have agg[j] = G'[j] + (2^num_rounds - 1) * any_base + // let any_point = (2^num_rounds - 1) * any_base // TODO: can we remove all these random point operations somehow? - let mut rand_point = ec_double(chip, ctx, rand_points.last().unwrap()); - rand_point = ec_sub_unequal(chip, ctx, &rand_point, &rand_points[0], false); + let mut any_point = ec_double(chip, ctx, any_points.last().unwrap()); + any_point = ec_sub_unequal(chip, ctx, any_point, &any_points[0], true); // compute sum_{k=0..scalar_bits} agg[k] * 2^k - (sum_{k=0..scalar_bits} 2^k) * rand_point // (sum_{k=0..scalar_bits} 2^k) = (2^scalar_bits - 1) - let mut sum = agg.pop().unwrap(); - let mut rand_sum = rand_point.clone(); + let mut sum = agg.pop().unwrap().into(); + let mut any_sum = any_point.clone(); for g in agg.iter().rev() { - rand_sum = ec_double(chip, ctx, &rand_sum); + any_sum = ec_double(chip, ctx, any_sum); // cannot use ec_double_and_add_unequal because you cannot guarantee that `sum != g` - sum = ec_double(chip, ctx, &sum); - chip.enforce_less_than(ctx, sum.x()); - sum = ec_add_unequal(chip, ctx, &sum, g, true); + sum = ec_double(chip, ctx, sum); + sum = ec_add_unequal(chip, ctx, sum, g, true); } - rand_sum = ec_double(chip, ctx, &rand_sum); - // assume 2^scalar_bits != +-1 mod modulus::() - rand_sum = ec_sub_unequal(chip, ctx, &rand_sum, &rand_point, false); + any_sum = ec_double(chip, ctx, any_sum); + any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, true); - chip.enforce_less_than(ctx, sum.x()); - chip.enforce_less_than(ctx, rand_sum.x()); - ec_sub_unequal(chip, ctx, &sum, &rand_sum, true) + ec_sub_strict(chip, ctx, sum, any_sum) } diff --git a/halo2-ecc/src/ecc/tests.rs b/halo2-ecc/src/ecc/tests.rs index fb9d7abf..5bbc612e 100644 --- a/halo2-ecc/src/ecc/tests.rs +++ b/halo2-ecc/src/ecc/tests.rs @@ -31,30 +31,30 @@ fn basic_g1_tests( let fp_chip = FpChip::::new(&range, limb_bits, num_limbs); let chip = EccChip::new(&fp_chip); - let P_assigned = chip.load_private(ctx, (P.x, P.y)); - let Q_assigned = chip.load_private(ctx, (Q.x, Q.y)); + let P_assigned = chip.load_private_unchecked(ctx, (P.x, P.y)); + let Q_assigned = chip.load_private_unchecked(ctx, (Q.x, Q.y)); // test add_unequal - chip.field_chip.enforce_less_than(ctx, P_assigned.x()); - chip.field_chip.enforce_less_than(ctx, Q_assigned.x()); + chip.field_chip.enforce_less_than(ctx, P_assigned.x().clone()); + chip.field_chip.enforce_less_than(ctx, Q_assigned.x().clone()); let sum = chip.add_unequal(ctx, &P_assigned, &Q_assigned, false); - assert_eq!(sum.x.truncation.to_bigint(limb_bits), sum.x.value); - assert_eq!(sum.y.truncation.to_bigint(limb_bits), sum.y.value); + assert_eq!(sum.x.0.truncation.to_bigint(limb_bits), sum.x.0.value); + assert_eq!(sum.y.0.truncation.to_bigint(limb_bits), sum.y.0.value); { let actual_sum = G1Affine::from(P + Q); - assert_eq!(bigint_to_fe::(&sum.x.value), actual_sum.x); - assert_eq!(bigint_to_fe::(&sum.y.value), actual_sum.y); + assert_eq!(bigint_to_fe::(&sum.x.0.value), actual_sum.x); + assert_eq!(bigint_to_fe::(&sum.y.0.value), actual_sum.y); } println!("add unequal witness OK"); // test double let doub = chip.double(ctx, &P_assigned); - assert_eq!(doub.x.truncation.to_bigint(limb_bits), doub.x.value); - assert_eq!(doub.y.truncation.to_bigint(limb_bits), doub.y.value); + assert_eq!(doub.x.0.truncation.to_bigint(limb_bits), doub.x.0.value); + assert_eq!(doub.y.0.truncation.to_bigint(limb_bits), doub.y.0.value); { let actual_doub = G1Affine::from(P * Fr::from(2u64)); - assert_eq!(bigint_to_fe::(&doub.x.value), actual_doub.x); - assert_eq!(bigint_to_fe::(&doub.y.value), actual_doub.y); + assert_eq!(bigint_to_fe::(&doub.x.0.value), actual_doub.x); + assert_eq!(bigint_to_fe::(&doub.y.0.value), actual_doub.y); } println!("double witness OK"); } diff --git a/halo2-ecc/src/fields/fp.rs b/halo2-ecc/src/fields/fp.rs index 6099a147..97bfd8b3 100644 --- a/halo2-ecc/src/fields/fp.rs +++ b/halo2-ecc/src/fields/fp.rs @@ -2,11 +2,11 @@ use super::{FieldChip, PrimeField, PrimeFieldChip, Selectable}; use crate::bigint::{ add_no_carry, big_is_equal, big_is_zero, carry_mod, check_carry_mod_to_zero, mul_no_carry, scalar_mul_and_add_no_carry, scalar_mul_no_carry, select, select_by_indicator, sub, - sub_no_carry, CRTInteger, FixedCRTInteger, OverflowInteger, + sub_no_carry, CRTInteger, FixedCRTInteger, OverflowInteger, ProperCrtUint, ProperUint, }; use crate::halo2_proofs::halo2curves::CurveAffine; use halo2_base::gates::RangeChip; -use halo2_base::utils::decompose_bigint; +use halo2_base::utils::ScalarField; use halo2_base::{ gates::{range::RangeConfig, GateInstructions, RangeInstructions}, utils::{bigint_to_fe, biguint_to_fe, bit_length, decompose_biguint, fe_to_biguint, modulus}, @@ -22,6 +22,28 @@ pub type BaseFieldChip<'range, C> = pub type FpConfig = RangeConfig; +/// Wrapper around `FieldPoint` to guarantee this is a "reduced" representation of an `Fp` field element. +/// A reduced representation guarantees that there is a *unique* representation of each field element. +/// Typically this means Uints that are less than the modulus. +#[derive(Clone, Debug)] +pub struct Reduced(pub(crate) FieldPoint, PhantomData); + +impl Reduced { + pub fn as_ref(&self) -> Reduced<&FieldPoint, Fp> { + Reduced(&self.0, PhantomData) + } + + pub fn inner(&self) -> &FieldPoint { + &self.0 + } +} + +impl From, Fp>> for ProperCrtUint { + fn from(x: Reduced, Fp>) -> Self { + x.0 + } +} + // `Fp` always needs to be `BigPrimeField`, we may later want support for `F` being just `ScalarField` but for optimization reasons we'll assume it's also `BigPrimeField` for now #[derive(Clone, Debug)] @@ -47,6 +69,9 @@ pub struct FpChip<'range, F: PrimeField, Fp: PrimeField> { impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { pub fn new(range: &'range RangeChip, limb_bits: usize, num_limbs: usize) -> Self { + assert!(limb_bits > 0); + assert!(num_limbs > 0); + assert!(limb_bits <= F::CAPACITY as usize); let limb_mask = (BigUint::from(1u64) << limb_bits) - 1usize; let p = modulus::(); let p_limbs = decompose_biguint(&p, num_limbs, limb_bits); @@ -77,14 +102,14 @@ impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { } } - pub fn enforce_less_than_p(&self, ctx: &mut Context, a: &CRTInteger) { + pub fn enforce_less_than_p(&self, ctx: &mut Context, a: ProperCrtUint) { // a < p iff a - p has underflow let mut borrow: Option> = None; - for (&p_limb, &a_limb) in self.p_limbs.iter().zip(a.truncation.limbs.iter()) { + for (&p_limb, a_limb) in self.p_limbs.iter().zip(a.0.truncation.limbs) { let lt = match borrow { None => self.range.is_less_than(ctx, a_limb, Constant(p_limb), self.limb_bits), Some(borrow) => { - let plus_borrow = self.range.gate.add(ctx, Constant(p_limb), borrow); + let plus_borrow = self.gate().add(ctx, Constant(p_limb), borrow); self.range.is_less_than( ctx, Existing(a_limb), @@ -95,7 +120,15 @@ impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { }; borrow = Some(lt); } - self.range.gate.assert_is_const(ctx, &borrow.unwrap(), &F::one()); + self.gate().assert_is_const(ctx, &borrow.unwrap(), &F::one()); + } + + pub fn load_constant_uint(&self, ctx: &mut Context, a: BigUint) -> ProperCrtUint { + FixedCRTInteger::from_native(a, self.num_limbs, self.limb_bits).assign( + ctx, + self.limb_bits, + self.native_modulus(), + ) } } @@ -113,9 +146,9 @@ impl<'range, F: PrimeField, Fp: PrimeField> PrimeFieldChip for FpChip<'range, impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, Fp> { const PRIME_FIELD_NUM_BITS: u32 = Fp::NUM_BITS; - type ConstantType = BigUint; - type WitnessType = BigInt; - type FieldPoint = CRTInteger; + type UnsafeFieldPoint = CRTInteger; + type FieldPoint = ProperCrtUint; + type ReducedFieldPoint = Reduced, Fp>; type FieldType = Fp; type RangeChip = RangeChip; @@ -133,135 +166,110 @@ impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, F bigint_to_fe(&(&x.value % &self.p)) } - fn fe_to_constant(x: Fp) -> BigUint { - fe_to_biguint(&x) - } - - fn fe_to_witness(x: &Fp) -> BigInt { - BigInt::from(fe_to_biguint(x)) - } - - fn load_private(&self, ctx: &mut Context, a: BigInt) -> CRTInteger { - let a_vec = decompose_bigint::(&a, self.num_limbs, self.limb_bits); + fn load_private(&self, ctx: &mut Context, a: Fp) -> ProperCrtUint { + let a = fe_to_biguint(&a); + let a_vec = decompose_biguint::(&a, self.num_limbs, self.limb_bits); let limbs = ctx.assign_witnesses(a_vec); - let a_native = OverflowInteger::::evaluate( - self.range.gate(), - ctx, - limbs.iter().copied(), - self.limb_bases.iter().copied(), - ); - let a_loaded = - CRTInteger::construct(OverflowInteger::construct(limbs, self.limb_bits), a_native, a); + ProperUint(limbs).into_crt(ctx, self.gate(), a, &self.limb_bases, self.limb_bits); - // TODO: this range check prevents loading witnesses that are not in "proper" representation form, is that ok? - self.range_check(ctx, &a_loaded, Self::PRIME_FIELD_NUM_BITS as usize); + self.range_check(ctx, a_loaded.clone(), Self::PRIME_FIELD_NUM_BITS as usize); a_loaded } - fn load_constant(&self, ctx: &mut Context, a: BigUint) -> CRTInteger { - let a_native = ctx.load_constant(biguint_to_fe(&(&a % self.native_modulus()))); - let a_limbs = decompose_biguint::(&a, self.num_limbs, self.limb_bits) - .into_iter() - .map(|c| ctx.load_constant(c)) - .collect(); - - CRTInteger::construct( - OverflowInteger::construct(a_limbs, self.limb_bits), - a_native, - BigInt::from(a), - ) + fn load_constant(&self, ctx: &mut Context, a: Fp) -> ProperCrtUint { + self.load_constant_uint(ctx, fe_to_biguint(&a)) } // signed overflow BigInt functions fn add_no_carry( &self, ctx: &mut Context, - a: &CRTInteger, - b: &CRTInteger, + a: impl Into>, + b: impl Into>, ) -> CRTInteger { - add_no_carry::crt::(self.range.gate(), ctx, a, b) + add_no_carry::crt(self.gate(), ctx, a.into(), b.into()) } fn add_constant_no_carry( &self, ctx: &mut Context, - a: &CRTInteger, - c: BigUint, + a: impl Into>, + c: Fp, ) -> CRTInteger { - let c = FixedCRTInteger::from_native(c, self.num_limbs, self.limb_bits); + let c = FixedCRTInteger::from_native(fe_to_biguint(&c), self.num_limbs, self.limb_bits); let c_native = biguint_to_fe::(&(&c.value % modulus::())); + let a = a.into(); let mut limbs = Vec::with_capacity(a.truncation.limbs.len()); - for (a_limb, c_limb) in a.truncation.limbs.iter().zip(c.truncation.limbs.into_iter()) { - let limb = self.range.gate.add(ctx, *a_limb, Constant(c_limb)); + for (a_limb, c_limb) in a.truncation.limbs.into_iter().zip(c.truncation.limbs) { + let limb = self.gate().add(ctx, a_limb, Constant(c_limb)); limbs.push(limb); } - let native = self.range.gate.add(ctx, a.native, Constant(c_native)); + let native = self.gate().add(ctx, a.native, Constant(c_native)); let trunc = - OverflowInteger::construct(limbs, max(a.truncation.max_limb_bits, self.limb_bits) + 1); - let value = &a.value + BigInt::from(c.value); + OverflowInteger::new(limbs, max(a.truncation.max_limb_bits, self.limb_bits) + 1); + let value = a.value + BigInt::from(c.value); - CRTInteger::construct(trunc, native, value) + CRTInteger::new(trunc, native, value) } fn sub_no_carry( &self, ctx: &mut Context, - a: &CRTInteger, - b: &CRTInteger, + a: impl Into>, + b: impl Into>, ) -> CRTInteger { - sub_no_carry::crt::(self.range.gate(), ctx, a, b) + sub_no_carry::crt::(self.gate(), ctx, a.into(), b.into()) } // Input: a // Output: p - a if a != 0, else a // Assume the actual value of `a` equals `a.truncation` // Constrains a.truncation <= p using subtraction with carries - fn negate(&self, ctx: &mut Context, a: &CRTInteger) -> CRTInteger { + fn negate(&self, ctx: &mut Context, a: ProperCrtUint) -> ProperCrtUint { // Compute p - a.truncation using carries - let p = self.load_constant(ctx, self.p.to_biguint().unwrap()); + let p = self.load_constant_uint(ctx, self.p.to_biguint().unwrap()); let (out_or_p, underflow) = - sub::crt::(self.range(), ctx, &p, a, self.limb_bits, self.limb_bases[1]); + sub::crt(self.range(), ctx, p, a.clone(), self.limb_bits, self.limb_bases[1]); // constrain underflow to equal 0 - self.range.gate.assert_is_const(ctx, &underflow, &F::zero()); + self.gate().assert_is_const(ctx, &underflow, &F::zero()); - let a_is_zero = big_is_zero::assign::(self.gate(), ctx, &a.truncation); - select::crt::(self.range.gate(), ctx, a, &out_or_p, a_is_zero) + let a_is_zero = big_is_zero::positive(self.gate(), ctx, a.0.truncation.clone()); + ProperCrtUint(select::crt(self.gate(), ctx, a.0, out_or_p, a_is_zero)) } fn scalar_mul_no_carry( &self, ctx: &mut Context, - a: &CRTInteger, + a: impl Into>, c: i64, ) -> CRTInteger { - scalar_mul_no_carry::crt::(self.range.gate(), ctx, a, c) + scalar_mul_no_carry::crt(self.gate(), ctx, a.into(), c) } fn scalar_mul_and_add_no_carry( &self, ctx: &mut Context, - a: &CRTInteger, - b: &CRTInteger, + a: impl Into>, + b: impl Into>, c: i64, ) -> CRTInteger { - scalar_mul_and_add_no_carry::crt::(self.range.gate(), ctx, a, b, c) + scalar_mul_and_add_no_carry::crt(self.gate(), ctx, a.into(), b.into(), c) } fn mul_no_carry( &self, ctx: &mut Context, - a: &CRTInteger, - b: &CRTInteger, + a: impl Into>, + b: impl Into>, ) -> CRTInteger { - mul_no_carry::crt::(self.range.gate(), ctx, a, b, self.num_limbs_log2_ceil) + mul_no_carry::crt(self.gate(), ctx, a.into(), b.into(), self.num_limbs_log2_ceil) } - fn check_carry_mod_to_zero(&self, ctx: &mut Context, a: &CRTInteger) { + fn check_carry_mod_to_zero(&self, ctx: &mut Context, a: CRTInteger) { check_carry_mod_to_zero::crt::( self.range(), - // &self.bigint_chip, ctx, a, self.num_limbs_bits, @@ -274,10 +282,9 @@ impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, F ) } - fn carry_mod(&self, ctx: &mut Context, a: &CRTInteger) -> CRTInteger { + fn carry_mod(&self, ctx: &mut Context, a: CRTInteger) -> ProperCrtUint { carry_mod::crt::( self.range(), - // &self.bigint_chip, ctx, a, self.num_limbs_bits, @@ -290,109 +297,177 @@ impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, F ) } + /// # Assumptions + /// * `max_bits` in `(n * (k - 1), n * k]` fn range_check( &self, ctx: &mut Context, - a: &CRTInteger, + a: impl Into>, max_bits: usize, // the maximum bits that a.value could take ) { let n = self.limb_bits; + let a = a.into(); let k = a.truncation.limbs.len(); debug_assert!(max_bits > n * (k - 1) && max_bits <= n * k); let last_limb_bits = max_bits - n * (k - 1); - #[cfg(debug_assertions)] debug_assert!(a.value.bits() as usize <= max_bits); // range check limbs of `a` are in [0, 2^n) except last limb should be in [0, 2^last_limb_bits) - for (i, cell) in a.truncation.limbs.iter().enumerate() { + for (i, cell) in a.truncation.limbs.into_iter().enumerate() { let limb_bits = if i == k - 1 { last_limb_bits } else { n }; - self.range.range_check(ctx, *cell, limb_bits); + self.range.range_check(ctx, cell, limb_bits); } } - fn enforce_less_than(&self, ctx: &mut Context, a: &Self::FieldPoint) { - self.enforce_less_than_p(ctx, a) + fn enforce_less_than( + &self, + ctx: &mut Context, + a: ProperCrtUint, + ) -> Reduced, Fp> { + self.enforce_less_than_p(ctx, a.clone()); + Reduced(a, PhantomData) } - fn is_soft_zero(&self, ctx: &mut Context, a: &CRTInteger) -> AssignedValue { - big_is_zero::crt::(self.gate(), ctx, a) - - // CHECK: I don't think this is necessary: - // underflow != 0 iff carry < p - // let p = self.load_constant(ctx, self.p.to_biguint().unwrap()); - // let (_, underflow) = - // sub::crt::(self.range(), ctx, a, &p, self.limb_bits, self.limb_bases[1]); - // let is_underflow_zero = self.gate().is_zero(ctx, &underflow); - // let range_check = self.gate().not(ctx, Existing(&is_underflow_zero)); - - // self.gate().and(ctx, is_zero, range_check) + /// Returns 1 iff `a` is 0 as a BigUint. This means that even if `a` is 0 modulo `p`, this may return 0. + fn is_soft_zero( + &self, + ctx: &mut Context, + a: impl Into>, + ) -> AssignedValue { + let a = a.into(); + big_is_zero::positive(self.gate(), ctx, a.0.truncation) } - fn is_soft_nonzero(&self, ctx: &mut Context, a: &CRTInteger) -> AssignedValue { - let is_zero = big_is_zero::crt::(self.gate(), ctx, a); + /// Given proper CRT integer `a`, returns 1 iff `a < modulus::()` and `a != 0` as integers + /// + /// # Assumptions + /// * `a` is proper representation of BigUint + fn is_soft_nonzero( + &self, + ctx: &mut Context, + a: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let is_zero = big_is_zero::positive(self.gate(), ctx, a.0.truncation.clone()); let is_nonzero = self.gate().not(ctx, is_zero); // underflow != 0 iff carry < p - let p = self.load_constant(ctx, self.p.to_biguint().unwrap()); + let p = self.load_constant_uint(ctx, self.p.to_biguint().unwrap()); let (_, underflow) = - sub::crt::(self.range(), ctx, a, &p, self.limb_bits, self.limb_bases[1]); + sub::crt::(self.range(), ctx, a, p, self.limb_bits, self.limb_bases[1]); let is_underflow_zero = self.gate().is_zero(ctx, underflow); - let range_check = self.gate().not(ctx, is_underflow_zero); + let no_underflow = self.gate().not(ctx, is_underflow_zero); - self.gate().and(ctx, is_nonzero, range_check) + self.gate().and(ctx, is_nonzero, no_underflow) } // assuming `a` has been range checked to be a proper BigInt // constrain the witness `a` to be `< p` // then check if `a` is 0 - fn is_zero(&self, ctx: &mut Context, a: &CRTInteger) -> AssignedValue { - self.enforce_less_than_p(ctx, a); + fn is_zero(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { + let a = a.into(); + self.enforce_less_than_p(ctx, a.clone()); // just check truncated limbs are all 0 since they determine the native value - big_is_zero::positive::(self.gate(), ctx, &a.truncation) + big_is_zero::positive(self.gate(), ctx, a.0.truncation) } fn is_equal_unenforced( &self, ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, + a: Reduced, Fp>, + b: Reduced, Fp>, ) -> AssignedValue { - big_is_equal::assign::(self.gate(), ctx, &a.truncation, &b.truncation) + big_is_equal::assign::(self.gate(), ctx, a.0, b.0) } // assuming `a, b` have been range checked to be a proper BigInt // constrain the witnesses `a, b` to be `< p` // then assert `a == b` as BigInts - fn assert_equal(&self, ctx: &mut Context, a: &Self::FieldPoint, b: &Self::FieldPoint) { - self.enforce_less_than_p(ctx, a); - self.enforce_less_than_p(ctx, b); + fn assert_equal( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) { + let a = a.into(); + let b = b.into(); // a.native and b.native are derived from `a.truncation, b.truncation`, so no need to check if they're equal - for (limb_a, limb_b) in a.truncation.limbs.iter().zip(b.truncation.limbs.iter()) { + for (limb_a, limb_b) in a.limbs().iter().zip(b.limbs().iter()) { ctx.constrain_equal(limb_a, limb_b); } + self.enforce_less_than_p(ctx, a); + self.enforce_less_than_p(ctx, b); } } -impl<'range, F: PrimeField, Fp: PrimeField> Selectable for FpChip<'range, F, Fp> { - type Point = CRTInteger; - +impl<'range, F: PrimeField, Fp: PrimeField> Selectable> for FpChip<'range, F, Fp> { fn select( &self, ctx: &mut Context, - a: &CRTInteger, - b: &CRTInteger, + a: CRTInteger, + b: CRTInteger, sel: AssignedValue, ) -> CRTInteger { - select::crt::(self.range.gate(), ctx, a, b, sel) + select::crt(self.gate(), ctx, a, b, sel) } fn select_by_indicator( &self, ctx: &mut Context, - a: &[CRTInteger], + a: &impl AsRef<[CRTInteger]>, coeffs: &[AssignedValue], ) -> CRTInteger { - select_by_indicator::crt::(self.range.gate(), ctx, a, coeffs, &self.limb_bases) + select_by_indicator::crt(self.gate(), ctx, a.as_ref(), coeffs, &self.limb_bases) + } +} + +impl<'range, F: PrimeField, Fp: PrimeField> Selectable> + for FpChip<'range, F, Fp> +{ + fn select( + &self, + ctx: &mut Context, + a: ProperCrtUint, + b: ProperCrtUint, + sel: AssignedValue, + ) -> ProperCrtUint { + ProperCrtUint(select::crt(self.gate(), ctx, a.0, b.0, sel)) + } + + fn select_by_indicator( + &self, + ctx: &mut Context, + a: &impl AsRef<[ProperCrtUint]>, + coeffs: &[AssignedValue], + ) -> ProperCrtUint { + let out = select_by_indicator::crt(self.gate(), ctx, a.as_ref(), coeffs, &self.limb_bases); + ProperCrtUint(out) + } +} + +impl Selectable> for FC +where + FC: Selectable, +{ + fn select( + &self, + ctx: &mut Context, + a: Reduced, + b: Reduced, + sel: AssignedValue, + ) -> Reduced { + Reduced(self.select(ctx, a.0, b.0, sel), PhantomData) + } + + fn select_by_indicator( + &self, + ctx: &mut Context, + a: &impl AsRef<[Reduced]>, + coeffs: &[AssignedValue], + ) -> Reduced { + // this is inefficient, could do std::mem::transmute but that is unsafe. hopefully compiler optimizes it out + let a = a.as_ref().iter().map(|a| a.0.clone()).collect::>(); + Reduced(self.select_by_indicator(ctx, &a, coeffs), PhantomData) } } diff --git a/halo2-ecc/src/fields/fp12.rs b/halo2-ecc/src/fields/fp12.rs index b82305ca..156ca452 100644 --- a/halo2-ecc/src/fields/fp12.rs +++ b/halo2-ecc/src/fields/fp12.rs @@ -1,13 +1,15 @@ -use super::{FieldChip, FieldExtConstructor, FieldExtPoint, PrimeField, PrimeFieldChip}; -use crate::halo2_proofs::arithmetic::Field; -use halo2_base::{ - gates::{GateInstructions, RangeInstructions}, - utils::fe_to_biguint, - AssignedValue, Context, -}; -use num_bigint::{BigInt, BigUint}; use std::marker::PhantomData; +use halo2_base::{utils::modulus, AssignedValue, Context}; +use num_bigint::BigUint; + +use crate::impl_field_ext_chip_common; + +use super::{ + vector::{FieldVector, FieldVectorChip}, + FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, +}; + /// Represent Fp12 point as FqPoint with degree = 12 /// `Fp12 = Fp2[w] / (w^6 - u - xi)` /// This implementation assumes p = 3 (mod 4) in order for the polynomial u^2 + 1 to @@ -15,264 +17,151 @@ use std::marker::PhantomData; /// This means we store an Fp12 point as `\sum_{i = 0}^6 (a_{i0} + a_{i1} * u) * w^i` /// This is encoded in an FqPoint of degree 12 as `(a_{00}, ..., a_{50}, a_{01}, ..., a_{51})` #[derive(Clone, Copy, Debug)] -pub struct Fp12Chip<'a, F: PrimeField, FpChip: PrimeFieldChip, Fp12: Field, const XI_0: i64> -where - FpChip::FieldType: PrimeField, -{ - // for historical reasons, leaving this as a reference - // for the current implementation we could also just use the de-referenced version: `fp_chip: FpChip` - pub fp_chip: &'a FpChip, - _f: PhantomData, - _fp12: PhantomData, -} +pub struct Fp12Chip<'a, F: PrimeField, FpChip: FieldChip, Fp12, const XI_0: i64>( + pub FieldVectorChip<'a, F, FpChip>, + PhantomData, +); impl<'a, F, FpChip, Fp12, const XI_0: i64> Fp12Chip<'a, F, FpChip, Fp12, XI_0> where F: PrimeField, FpChip: PrimeFieldChip, FpChip::FieldType: PrimeField, - Fp12: Field + FieldExtConstructor, + Fp12: ff::Field, { /// User must construct an `FpChip` first using a config. This is intended so everything shares a single `FlexGateChip`, which is needed for the column allocation to work. pub fn new(fp_chip: &'a FpChip) -> Self { - Self { fp_chip, _f: PhantomData, _fp12: PhantomData } + assert_eq!( + modulus::() % 4usize, + BigUint::from(3u64), + "p must be 3 (mod 4) for the polynomial u^2 + 1 to be irreducible" + ); + Self(FieldVectorChip::new(fp_chip), PhantomData) + } + + pub fn fp_chip(&self) -> &FpChip { + self.0.fp_chip } pub fn fp2_mul_no_carry( &self, ctx: &mut Context, - a: &FieldExtPoint, - fp2_pt: &FieldExtPoint, - ) -> FieldExtPoint { - assert_eq!(a.coeffs.len(), 12); - assert_eq!(fp2_pt.coeffs.len(), 2); - + fp12_pt: FieldVector, + fp2_pt: FieldVector, + ) -> FieldVector { + let fp12_pt = fp12_pt.0; + let fp2_pt = fp2_pt.0; + assert_eq!(fp12_pt.len(), 12); + assert_eq!(fp2_pt.len(), 2); + + let fp_chip = self.fp_chip(); let mut out_coeffs = Vec::with_capacity(12); for i in 0..6 { - let coeff1 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i], &fp2_pt.coeffs[0]); - let coeff2 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i + 6], &fp2_pt.coeffs[1]); - let coeff = self.fp_chip.sub_no_carry(ctx, &coeff1, &coeff2); + let coeff1 = fp_chip.mul_no_carry(ctx, fp12_pt[i].clone(), fp2_pt[0].clone()); + let coeff2 = fp_chip.mul_no_carry(ctx, fp12_pt[i + 6].clone(), fp2_pt[1].clone()); + let coeff = fp_chip.sub_no_carry(ctx, coeff1, coeff2); out_coeffs.push(coeff); } for i in 0..6 { - let coeff1 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i + 6], &fp2_pt.coeffs[0]); - let coeff2 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i], &fp2_pt.coeffs[1]); - let coeff = self.fp_chip.add_no_carry(ctx, &coeff1, &coeff2); + let coeff1 = fp_chip.mul_no_carry(ctx, fp12_pt[i + 6].clone(), fp2_pt[0].clone()); + let coeff2 = fp_chip.mul_no_carry(ctx, fp12_pt[i].clone(), fp2_pt[1].clone()); + let coeff = fp_chip.add_no_carry(ctx, coeff1, coeff2); out_coeffs.push(coeff); } - FieldExtPoint::construct(out_coeffs) + FieldVector(out_coeffs) } // for \sum_i (a_i + b_i u) w^i, returns \sum_i (-1)^i (a_i + b_i u) w^i pub fn conjugate( &self, ctx: &mut Context, - a: &FieldExtPoint, - ) -> FieldExtPoint { - assert_eq!(a.coeffs.len(), 12); + a: FieldVector, + ) -> FieldVector { + let a = a.0; + assert_eq!(a.len(), 12); let coeffs = a - .coeffs - .iter() + .into_iter() .enumerate() - .map(|(i, c)| if i % 2 == 0 { c.clone() } else { self.fp_chip.negate(ctx, c) }) + .map(|(i, c)| if i % 2 == 0 { c } else { self.fp_chip().negate(ctx, c) }) .collect(); - FieldExtPoint::construct(coeffs) + FieldVector(coeffs) } } -/// multiply (a0 + a1 * u) * (XI0 + u) without carry +/// multiply Fp2 elts: (a0 + a1 * u) * (XI0 + u) without carry +/// +/// # Assumptions +/// * `a` is `Fp2` point represented as `FieldVector` with degree = 2 pub fn mul_no_carry_w6, const XI_0: i64>( fp_chip: &FC, ctx: &mut Context, - a: &FieldExtPoint, -) -> FieldExtPoint { - assert_eq!(a.coeffs.len(), 2); - let (a0, a1) = (&a.coeffs[0], &a.coeffs[1]); + a: FieldVector, +) -> FieldVector { + let [a0, a1]: [_; 2] = a.0.try_into().unwrap(); // (a0 + a1 u) * (XI_0 + u) = (a0 * XI_0 - a1) + (a1 * XI_0 + a0) u with u^2 = -1 // This should fit in the overflow representation if limb_bits is large enough - let a0_xi0 = fp_chip.scalar_mul_no_carry(ctx, a0, XI_0); - let out0_0_nocarry = fp_chip.sub_no_carry(ctx, &a0_xi0, a1); + let a0_xi0 = fp_chip.scalar_mul_no_carry(ctx, a0.clone(), XI_0); + let out0_0_nocarry = fp_chip.sub_no_carry(ctx, a0_xi0, a1.clone()); let out0_1_nocarry = fp_chip.scalar_mul_and_add_no_carry(ctx, a1, a0, XI_0); - FieldExtPoint::construct(vec![out0_0_nocarry, out0_1_nocarry]) + FieldVector(vec![out0_0_nocarry, out0_1_nocarry]) } // a lot of this is common to any field extension (lots of for loops), but due to the way rust traits work, it is hard to create a common generic trait that does this. The main problem is that if you had a `FieldExtCommon` trait and wanted to implement `FieldChip` for anything with `FieldExtCommon`, rust will stop you because someone could implement `FieldExtCommon` and `FieldChip` for the same type, causing a conflict. +// partially solved using macro + impl<'a, F, FpChip, Fp12, const XI_0: i64> FieldChip for Fp12Chip<'a, F, FpChip, Fp12, XI_0> where F: PrimeField, - FpChip: PrimeFieldChip, + FpChip: PrimeFieldChip, FpChip::FieldType: PrimeField, - Fp12: Field + FieldExtConstructor, + Fp12: ff::Field + FieldExtConstructor, + FieldVector: From>, + FieldVector: From>, { const PRIME_FIELD_NUM_BITS: u32 = FpChip::FieldType::NUM_BITS; - type ConstantType = Fp12; - type WitnessType = Vec; - type FieldPoint = FieldExtPoint; + type UnsafeFieldPoint = FieldVector; + type FieldPoint = FieldVector; + type ReducedFieldPoint = FieldVector; type FieldType = Fp12; type RangeChip = FpChip::RangeChip; - fn native_modulus(&self) -> &BigUint { - self.fp_chip.native_modulus() - } - fn range(&self) -> &Self::RangeChip { - self.fp_chip.range() - } - - fn limb_bits(&self) -> usize { - self.fp_chip.limb_bits() - } - - fn get_assigned_value(&self, x: &Self::FieldPoint) -> Fp12 { - assert_eq!(x.coeffs.len(), 12); - let values = - x.coeffs.iter().map(|v| self.fp_chip.get_assigned_value(v)).collect::>(); + fn get_assigned_value(&self, x: &Self::UnsafeFieldPoint) -> Fp12 { + assert_eq!(x.0.len(), 12); + let values = x.0.iter().map(|v| self.fp_chip().get_assigned_value(v)).collect::>(); Fp12::new(values.try_into().unwrap()) } - fn fe_to_constant(x: Self::FieldType) -> Self::ConstantType { - x - } - fn fe_to_witness(x: &Fp12) -> Vec { - x.coeffs().iter().map(|c| BigInt::from(fe_to_biguint(c))).collect() - } - - fn load_private(&self, ctx: &mut Context, coeffs: Vec) -> Self::FieldPoint { - assert_eq!(coeffs.len(), 12); - let mut assigned_coeffs = Vec::with_capacity(12); - for a in coeffs { - let assigned_coeff = self.fp_chip.load_private(ctx, a.clone()); - assigned_coeffs.push(assigned_coeff); - } - Self::FieldPoint::construct(assigned_coeffs) - } - - fn load_constant(&self, ctx: &mut Context, c: Fp12) -> Self::FieldPoint { - let mut assigned_coeffs = Vec::with_capacity(12); - for a in &c.coeffs() { - let assigned_coeff = self.fp_chip.load_constant(ctx, fe_to_biguint(a)); - assigned_coeffs.push(assigned_coeff); - } - Self::FieldPoint::construct(assigned_coeffs) - } - - // signed overflow BigInt functions - fn add_no_carry( - &self, - ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, - ) -> Self::FieldPoint { - assert_eq!(a.coeffs.len(), b.coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.add_no_carry(ctx, &a.coeffs[i], &b.coeffs[i]); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn add_constant_no_carry( - &self, - ctx: &mut Context, - a: &Self::FieldPoint, - c: Self::ConstantType, - ) -> Self::FieldPoint { - let c_coeffs = c.coeffs(); - assert_eq!(a.coeffs.len(), c_coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for (a, c) in a.coeffs.iter().zip(c_coeffs.into_iter()) { - let coeff = self.fp_chip.add_constant_no_carry(ctx, a, FpChip::fe_to_constant(c)); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn sub_no_carry( - &self, - ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, - ) -> Self::FieldPoint { - assert_eq!(a.coeffs.len(), b.coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.sub_no_carry(ctx, &a.coeffs[i], &b.coeffs[i]); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn negate(&self, ctx: &mut Context, a: &Self::FieldPoint) -> Self::FieldPoint { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for a_coeff in &a.coeffs { - let out_coeff = self.fp_chip.negate(ctx, a_coeff); - out_coeffs.push(out_coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn scalar_mul_no_carry( - &self, - ctx: &mut Context, - a: &Self::FieldPoint, - c: i64, - ) -> Self::FieldPoint { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.scalar_mul_no_carry(ctx, &a.coeffs[i], c); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn scalar_mul_and_add_no_carry( - &self, - ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, - c: i64, - ) -> Self::FieldPoint { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = - self.fp_chip.scalar_mul_and_add_no_carry(ctx, &a.coeffs[i], &b.coeffs[i], c); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - // w^6 = u + xi for xi = 9 fn mul_no_carry( &self, ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, - ) -> Self::FieldPoint { - assert_eq!(a.coeffs.len(), 12); - assert_eq!(b.coeffs.len(), 12); - + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint { + let a = a.into().0; + let b = b.into().0; + assert_eq!(a.len(), 12); + assert_eq!(b.len(), 12); + + let fp_chip = self.fp_chip(); // a = \sum_{i = 0}^5 (a_i * w^i + a_{i + 6} * w^i * u) // b = \sum_{i = 0}^5 (b_i * w^i + b_{i + 6} * w^i * u) - let mut a0b0_coeffs = Vec::with_capacity(11); - let mut a0b1_coeffs = Vec::with_capacity(11); - let mut a1b0_coeffs = Vec::with_capacity(11); - let mut a1b1_coeffs = Vec::with_capacity(11); + let mut a0b0_coeffs: Vec = Vec::with_capacity(11); + let mut a0b1_coeffs: Vec = Vec::with_capacity(11); + let mut a1b0_coeffs: Vec = Vec::with_capacity(11); + let mut a1b1_coeffs: Vec = Vec::with_capacity(11); for i in 0..6 { for j in 0..6 { - let coeff00 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i], &b.coeffs[j]); - let coeff01 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i], &b.coeffs[j + 6]); - let coeff10 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i + 6], &b.coeffs[j]); - let coeff11 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i + 6], &b.coeffs[j + 6]); + let coeff00 = fp_chip.mul_no_carry(ctx, &a[i], &b[j]); + let coeff01 = fp_chip.mul_no_carry(ctx, &a[i], &b[j + 6]); + let coeff10 = fp_chip.mul_no_carry(ctx, &a[i + 6], &b[j]); + let coeff11 = fp_chip.mul_no_carry(ctx, &a[i + 6], &b[j + 6]); if i + j < a0b0_coeffs.len() { - a0b0_coeffs[i + j] = - self.fp_chip.add_no_carry(ctx, &a0b0_coeffs[i + j], &coeff00); - a0b1_coeffs[i + j] = - self.fp_chip.add_no_carry(ctx, &a0b1_coeffs[i + j], &coeff01); - a1b0_coeffs[i + j] = - self.fp_chip.add_no_carry(ctx, &a1b0_coeffs[i + j], &coeff10); - a1b1_coeffs[i + j] = - self.fp_chip.add_no_carry(ctx, &a1b1_coeffs[i + j], &coeff11); + a0b0_coeffs[i + j] = fp_chip.add_no_carry(ctx, &a0b0_coeffs[i + j], coeff00); + a0b1_coeffs[i + j] = fp_chip.add_no_carry(ctx, &a0b1_coeffs[i + j], coeff01); + a1b0_coeffs[i + j] = fp_chip.add_no_carry(ctx, &a1b0_coeffs[i + j], coeff10); + a1b1_coeffs[i + j] = fp_chip.add_no_carry(ctx, &a1b1_coeffs[i + j], coeff11); } else { a0b0_coeffs.push(coeff00); a0b1_coeffs.push(coeff01); @@ -285,10 +174,8 @@ where let mut a0b0_minus_a1b1 = Vec::with_capacity(11); let mut a0b1_plus_a1b0 = Vec::with_capacity(11); for i in 0..11 { - let a0b0_minus_a1b1_entry = - self.fp_chip.sub_no_carry(ctx, &a0b0_coeffs[i], &a1b1_coeffs[i]); - let a0b1_plus_a1b0_entry = - self.fp_chip.add_no_carry(ctx, &a0b1_coeffs[i], &a1b0_coeffs[i]); + let a0b0_minus_a1b1_entry = fp_chip.sub_no_carry(ctx, &a0b0_coeffs[i], &a1b1_coeffs[i]); + let a0b1_plus_a1b0_entry = fp_chip.add_no_carry(ctx, &a0b1_coeffs[i], &a1b0_coeffs[i]); a0b0_minus_a1b1.push(a0b0_minus_a1b1_entry); a0b1_plus_a1b0.push(a0b1_plus_a1b0_entry); @@ -299,13 +186,13 @@ where let mut out_coeffs = Vec::with_capacity(12); for i in 0..6 { if i < 5 { - let mut coeff = self.fp_chip.scalar_mul_and_add_no_carry( + let mut coeff = fp_chip.scalar_mul_and_add_no_carry( ctx, &a0b0_minus_a1b1[i + 6], &a0b0_minus_a1b1[i], XI_0, ); - coeff = self.fp_chip.sub_no_carry(ctx, &coeff, &a0b1_plus_a1b0[i + 6]); + coeff = fp_chip.sub_no_carry(ctx, coeff, &a0b1_plus_a1b0[i + 6]); out_coeffs.push(coeff); } else { out_coeffs.push(a0b0_minus_a1b1[i].clone()); @@ -314,131 +201,18 @@ where for i in 0..6 { if i < 5 { let mut coeff = - self.fp_chip.add_no_carry(ctx, &a0b1_plus_a1b0[i], &a0b0_minus_a1b1[i + 6]); - coeff = self.fp_chip.scalar_mul_and_add_no_carry( - ctx, - &a0b1_plus_a1b0[i + 6], - &coeff, - XI_0, - ); + fp_chip.add_no_carry(ctx, &a0b1_plus_a1b0[i], &a0b0_minus_a1b1[i + 6]); + coeff = + fp_chip.scalar_mul_and_add_no_carry(ctx, &a0b1_plus_a1b0[i + 6], coeff, XI_0); out_coeffs.push(coeff); } else { out_coeffs.push(a0b1_plus_a1b0[i].clone()); } } - Self::FieldPoint::construct(out_coeffs) - } - - fn check_carry_mod_to_zero(&self, ctx: &mut Context, a: &Self::FieldPoint) { - for coeff in &a.coeffs { - self.fp_chip.check_carry_mod_to_zero(ctx, coeff); - } - } - - fn carry_mod(&self, ctx: &mut Context, a: &Self::FieldPoint) -> Self::FieldPoint { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.carry_mod(ctx, a_coeff); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn range_check(&self, ctx: &mut Context, a: &Self::FieldPoint, max_bits: usize) { - for a_coeff in &a.coeffs { - self.fp_chip.range_check(ctx, a_coeff, max_bits); - } - } - - fn enforce_less_than(&self, ctx: &mut Context, a: &Self::FieldPoint) { - for a_coeff in &a.coeffs { - self.fp_chip.enforce_less_than(ctx, a_coeff) - } - } - - fn is_soft_zero(&self, ctx: &mut Context, a: &Self::FieldPoint) -> AssignedValue { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_soft_zero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.fp_chip.range().gate().and(ctx, coeff, p); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() - } - - fn is_soft_nonzero(&self, ctx: &mut Context, a: &Self::FieldPoint) -> AssignedValue { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_soft_nonzero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.gate().or(ctx, coeff, p); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() - } - - fn is_zero(&self, ctx: &mut Context, a: &Self::FieldPoint) -> AssignedValue { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_zero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.gate().and(ctx, coeff, p); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() + FieldVector(out_coeffs) } - fn is_equal( - &self, - ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, - ) -> AssignedValue { - let mut acc = None; - for (a_coeff, b_coeff) in a.coeffs.iter().zip(b.coeffs.iter()) { - let coeff = self.fp_chip.is_equal(ctx, a_coeff, b_coeff); - if let Some(c) = acc { - acc = Some(self.gate().and(ctx, coeff, c)); - } else { - acc = Some(coeff); - } - } - acc.unwrap() - } - - fn is_equal_unenforced( - &self, - ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, - ) -> AssignedValue { - let mut acc = None; - for (a_coeff, b_coeff) in a.coeffs.iter().zip(b.coeffs.iter()) { - let coeff = self.fp_chip.is_equal_unenforced(ctx, a_coeff, b_coeff); - if let Some(c) = acc { - acc = Some(self.gate().and(ctx, coeff, c)); - } else { - acc = Some(coeff); - } - } - acc.unwrap() - } - - fn assert_equal(&self, ctx: &mut Context, a: &Self::FieldPoint, b: &Self::FieldPoint) { - for (a_coeff, b_coeff) in a.coeffs.iter().zip(b.coeffs.iter()) { - self.fp_chip.assert_equal(ctx, a_coeff, b_coeff); - } - } + impl_field_ext_chip_common!(); } mod bn254 { diff --git a/halo2-ecc/src/fields/fp2.rs b/halo2-ecc/src/fields/fp2.rs index aed390fa..55e3243a 100644 --- a/halo2-ecc/src/fields/fp2.rs +++ b/halo2-ecc/src/fields/fp2.rs @@ -1,94 +1,66 @@ +use std::fmt::Debug; +use std::marker::PhantomData; + +use halo2_base::{utils::modulus, AssignedValue, Context}; +use num_bigint::BigUint; + +use crate::impl_field_ext_chip_common; + use super::{ - FieldChip, FieldExtConstructor, FieldExtPoint, PrimeField, PrimeFieldChip, Selectable, + vector::{FieldVector, FieldVectorChip}, + FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, }; -use crate::halo2_proofs::arithmetic::Field; -use halo2_base::{gates::GateInstructions, utils::fe_to_biguint, AssignedValue, Context}; -use num_bigint::{BigInt, BigUint}; -use std::marker::PhantomData; -/// Represent Fp2 point as `FieldExtPoint` with degree = 2 +/// Represent Fp2 point as `FieldVector` with degree = 2 /// `Fp2 = Fp[u] / (u^2 + 1)` /// This implementation assumes p = 3 (mod 4) in order for the polynomial u^2 + 1 to be irreducible over Fp; i.e., in order for -1 to not be a square (quadratic residue) in Fp /// This means we store an Fp2 point as `a_0 + a_1 * u` where `a_0, a_1 in Fp` #[derive(Clone, Copy, Debug)] -pub struct Fp2Chip<'a, F: PrimeField, FpChip: PrimeFieldChip, Fp2: Field> -where - FpChip::FieldType: PrimeField, -{ - // for historical reasons, leaving this as a reference - // for the current implementation we could also just use the de-referenced version: `fp_chip: FpChip` - pub fp_chip: &'a FpChip, - _f: PhantomData, - _fp2: PhantomData, -} +pub struct Fp2Chip<'a, F: PrimeField, FpChip: FieldChip, Fp2>( + pub FieldVectorChip<'a, F, FpChip>, + PhantomData, +); -impl<'a, F, FpChip, Fp2> Fp2Chip<'a, F, FpChip, Fp2> +impl<'a, F: PrimeField, FpChip: PrimeFieldChip, Fp2: ff::Field> Fp2Chip<'a, F, FpChip, Fp2> where - F: PrimeField, - FpChip: PrimeFieldChip, FpChip::FieldType: PrimeField, - Fp2: Field + FieldExtConstructor, { /// User must construct an `FpChip` first using a config. This is intended so everything shares a single `FlexGateChip`, which is needed for the column allocation to work. pub fn new(fp_chip: &'a FpChip) -> Self { - Self { fp_chip, _f: PhantomData, _fp2: PhantomData } + assert_eq!( + modulus::() % 4usize, + BigUint::from(3u64), + "p must be 3 (mod 4) for the polynomial u^2 + 1 to be irreducible" + ); + Self(FieldVectorChip::new(fp_chip), PhantomData) } - pub fn fp_mul_no_carry( - &self, - ctx: &mut Context, - a: &FieldExtPoint, - fp_point: &FpChip::FieldPoint, - ) -> FieldExtPoint { - assert_eq!(a.coeffs.len(), 2); - - let mut out_coeffs = Vec::with_capacity(2); - for c in &a.coeffs { - let coeff = self.fp_chip.mul_no_carry(ctx, c, fp_point); - out_coeffs.push(coeff); - } - FieldExtPoint::construct(out_coeffs) + pub fn fp_chip(&self) -> &FpChip { + self.0.fp_chip } pub fn conjugate( &self, ctx: &mut Context, - a: &FieldExtPoint, - ) -> FieldExtPoint { - assert_eq!(a.coeffs.len(), 2); + a: FieldVector, + ) -> FieldVector { + let mut a = a.0; + assert_eq!(a.len(), 2); - let neg_a1 = self.fp_chip.negate(ctx, &a.coeffs[1]); - FieldExtPoint::construct(vec![a.coeffs[0].clone(), neg_a1]) + let neg_a1 = self.fp_chip().negate(ctx, a.pop().unwrap()); + FieldVector(vec![a.pop().unwrap(), neg_a1]) } pub fn neg_conjugate( &self, ctx: &mut Context, - a: &FieldExtPoint, - ) -> FieldExtPoint { - assert_eq!(a.coeffs.len(), 2); + a: FieldVector, + ) -> FieldVector { + assert_eq!(a.0.len(), 2); + let mut a = a.0.into_iter(); - let neg_a0 = self.fp_chip.negate(ctx, &a.coeffs[0]); - FieldExtPoint::construct(vec![neg_a0, a.coeffs[1].clone()]) - } - - pub fn select( - &self, - ctx: &mut Context, - a: &FieldExtPoint, - b: &FieldExtPoint, - sel: AssignedValue, - ) -> FieldExtPoint - where - FpChip: Selectable, - { - let coeffs: Vec<_> = a - .coeffs - .iter() - .zip(b.coeffs.iter()) - .map(|(a, b)| self.fp_chip.select(ctx, a, b, sel)) - .collect(); - FieldExtPoint::construct(coeffs) + let neg_a0 = self.fp_chip().negate(ctx, a.next().unwrap()); + FieldVector(vec![neg_a0, a.next().unwrap()]) } } @@ -96,268 +68,52 @@ impl<'a, F, FpChip, Fp2> FieldChip for Fp2Chip<'a, F, FpChip, Fp2> where F: PrimeField, FpChip::FieldType: PrimeField, - FpChip: PrimeFieldChip, - Fp2: Field + FieldExtConstructor, + FpChip: PrimeFieldChip, + Fp2: ff::Field + FieldExtConstructor, + FieldVector: From>, + FieldVector: From>, { const PRIME_FIELD_NUM_BITS: u32 = FpChip::FieldType::NUM_BITS; - type ConstantType = Fp2; - type WitnessType = Vec; - type FieldPoint = FieldExtPoint; + type UnsafeFieldPoint = FieldVector; + type FieldPoint = FieldVector; + type ReducedFieldPoint = FieldVector; type FieldType = Fp2; type RangeChip = FpChip::RangeChip; - fn native_modulus(&self) -> &BigUint { - self.fp_chip.native_modulus() - } - fn range(&self) -> &Self::RangeChip { - self.fp_chip.range() - } - - fn limb_bits(&self) -> usize { - self.fp_chip.limb_bits() - } - - fn get_assigned_value(&self, x: &Self::FieldPoint) -> Fp2 { - debug_assert_eq!(x.coeffs.len(), 2); - let c0 = self.fp_chip.get_assigned_value(&x.coeffs[0]); - let c1 = self.fp_chip.get_assigned_value(&x.coeffs[1]); + fn get_assigned_value(&self, x: &Self::UnsafeFieldPoint) -> Fp2 { + assert_eq!(x.0.len(), 2); + let c0 = self.fp_chip().get_assigned_value(&x[0]); + let c1 = self.fp_chip().get_assigned_value(&x[1]); Fp2::new([c0, c1]) } - fn fe_to_constant(x: Fp2) -> Fp2 { - x - } - - fn fe_to_witness(x: &Fp2) -> Vec { - let coeffs = x.coeffs(); - debug_assert_eq!(coeffs.len(), 2); - coeffs.iter().map(|c| BigInt::from(fe_to_biguint(c))).collect() - } - - fn load_private(&self, ctx: &mut Context, coeffs: Vec) -> Self::FieldPoint { - debug_assert_eq!(coeffs.len(), 2); - let mut assigned_coeffs = Vec::with_capacity(2); - for a in coeffs { - let assigned_coeff = self.fp_chip.load_private(ctx, a); - assigned_coeffs.push(assigned_coeff); - } - Self::FieldPoint::construct(assigned_coeffs) - } - - fn load_constant(&self, ctx: &mut Context, c: Fp2) -> Self::FieldPoint { - let mut assigned_coeffs = Vec::with_capacity(2); - for a in &c.coeffs() { - let assigned_coeff = self.fp_chip.load_constant(ctx, fe_to_biguint(a)); - assigned_coeffs.push(assigned_coeff); - } - Self::FieldPoint::construct(assigned_coeffs) - } - - // signed overflow BigInt functions - fn add_no_carry( - &self, - ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, - ) -> Self::FieldPoint { - assert_eq!(a.coeffs.len(), b.coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.add_no_carry(ctx, &a.coeffs[i], &b.coeffs[i]); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn add_constant_no_carry( - &self, - ctx: &mut Context, - a: &Self::FieldPoint, - c: Self::ConstantType, - ) -> Self::FieldPoint { - let c_coeffs = c.coeffs(); - assert_eq!(a.coeffs.len(), c_coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for (a, c) in a.coeffs.iter().zip(c_coeffs.into_iter()) { - let coeff = self.fp_chip.add_constant_no_carry(ctx, a, FpChip::fe_to_constant(c)); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn sub_no_carry( - &self, - ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, - ) -> Self::FieldPoint { - assert_eq!(a.coeffs.len(), b.coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.sub_no_carry(ctx, &a.coeffs[i], &b.coeffs[i]); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn negate(&self, ctx: &mut Context, a: &Self::FieldPoint) -> Self::FieldPoint { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for a_coeff in &a.coeffs { - let out_coeff = self.fp_chip.negate(ctx, a_coeff); - out_coeffs.push(out_coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn scalar_mul_no_carry( - &self, - ctx: &mut Context, - a: &Self::FieldPoint, - c: i64, - ) -> Self::FieldPoint { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.scalar_mul_no_carry(ctx, &a.coeffs[i], c); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn scalar_mul_and_add_no_carry( - &self, - ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, - c: i64, - ) -> Self::FieldPoint { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = - self.fp_chip.scalar_mul_and_add_no_carry(ctx, &a.coeffs[i], &b.coeffs[i], c); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - fn mul_no_carry( &self, ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, - ) -> Self::FieldPoint { - assert_eq!(a.coeffs.len(), b.coeffs.len()); + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint { + let a = a.into().0; + let b = b.into().0; + assert_eq!(a.len(), 2); + assert_eq!(b.len(), 2); + let fp_chip = self.fp_chip(); // (a_0 + a_1 * u) * (b_0 + b_1 * u) = (a_0 b_0 - a_1 b_1) + (a_0 b_1 + a_1 b_0) * u - let mut ab_coeffs = Vec::with_capacity(a.coeffs.len() * b.coeffs.len()); - for i in 0..a.coeffs.len() { - for j in 0..b.coeffs.len() { - let coeff = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i], &b.coeffs[j]); + let mut ab_coeffs = Vec::with_capacity(4); + for a_i in a { + for b_j in b.iter() { + let coeff = fp_chip.mul_no_carry(ctx, &a_i, b_j); ab_coeffs.push(coeff); } } - let a0b0_minus_a1b1 = - self.fp_chip.sub_no_carry(ctx, &ab_coeffs[0], &ab_coeffs[b.coeffs.len() + 1]); - let a0b1_plus_a1b0 = - self.fp_chip.add_no_carry(ctx, &ab_coeffs[1], &ab_coeffs[b.coeffs.len()]); - - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - out_coeffs.push(a0b0_minus_a1b1); - out_coeffs.push(a0b1_plus_a1b0); - - Self::FieldPoint::construct(out_coeffs) - } - - fn check_carry_mod_to_zero(&self, ctx: &mut Context, a: &Self::FieldPoint) { - for coeff in &a.coeffs { - self.fp_chip.check_carry_mod_to_zero(ctx, coeff); - } - } - - fn carry_mod(&self, ctx: &mut Context, a: &Self::FieldPoint) -> Self::FieldPoint { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.carry_mod(ctx, a_coeff); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn range_check(&self, ctx: &mut Context, a: &Self::FieldPoint, max_bits: usize) { - for a_coeff in &a.coeffs { - self.fp_chip.range_check(ctx, a_coeff, max_bits); - } - } - - fn enforce_less_than(&self, ctx: &mut Context, a: &Self::FieldPoint) { - for a_coeff in &a.coeffs { - self.fp_chip.enforce_less_than(ctx, a_coeff) - } - } - - fn is_soft_zero(&self, ctx: &mut Context, a: &Self::FieldPoint) -> AssignedValue { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_soft_zero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.gate().and(ctx, coeff, p); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() - } - - fn is_soft_nonzero(&self, ctx: &mut Context, a: &Self::FieldPoint) -> AssignedValue { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_soft_nonzero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.gate().or(ctx, coeff, p); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() - } - - fn is_zero(&self, ctx: &mut Context, a: &Self::FieldPoint) -> AssignedValue { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_zero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.gate().and(ctx, coeff, p); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() - } + let a0b0_minus_a1b1 = fp_chip.sub_no_carry(ctx, &ab_coeffs[0], &ab_coeffs[3]); + let a0b1_plus_a1b0 = fp_chip.add_no_carry(ctx, &ab_coeffs[1], &ab_coeffs[2]); - fn is_equal_unenforced( - &self, - ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, - ) -> AssignedValue { - let mut acc = None; - for (a_coeff, b_coeff) in a.coeffs.iter().zip(b.coeffs.iter()) { - let coeff = self.fp_chip.is_equal_unenforced(ctx, a_coeff, b_coeff); - if let Some(c) = acc { - acc = Some(self.gate().and(ctx, coeff, c)); - } else { - acc = Some(coeff); - } - } - acc.unwrap() + FieldVector(vec![a0b0_minus_a1b1, a0b1_plus_a1b0]) } - fn assert_equal(&self, ctx: &mut Context, a: &Self::FieldPoint, b: &Self::FieldPoint) { - for (a_coeff, b_coeff) in a.coeffs.iter().zip(b.coeffs.iter()) { - self.fp_chip.assert_equal(ctx, a_coeff, b_coeff) - } - } + // ========= inherited from FieldVectorChip ========= + impl_field_ext_chip_common!(); } mod bn254 { diff --git a/halo2-ecc/src/fields/mod.rs b/halo2-ecc/src/fields/mod.rs index cdae8275..0c55affa 100644 --- a/halo2-ecc/src/fields/mod.rs +++ b/halo2-ecc/src/fields/mod.rs @@ -1,6 +1,6 @@ use crate::halo2_proofs::arithmetic::Field; use halo2_base::{ - gates::RangeInstructions, + gates::{GateInstructions, RangeInstructions}, utils::{BigPrimeField, ScalarField}, AssignedValue, Context, }; @@ -11,37 +11,42 @@ use std::fmt::Debug; pub mod fp; pub mod fp12; pub mod fp2; +pub mod vector; #[cfg(test)] mod tests; pub trait PrimeField = BigPrimeField; -#[derive(Clone, Debug)] -pub struct FieldExtPoint { - // `F_q` field extension of `F_p` where `q = p^degree` - // An `F_q` point consists of `degree` number of `F_p` points - // The `F_p` points are stored as `FieldPoint`s - - // We do not specify the irreducible `F_p` polynomial used to construct `F_q` here - that is implementation specific - pub coeffs: Vec, - // `degree = coeffs.len()` -} - -impl FieldExtPoint { - pub fn construct(coeffs: Vec) -> Self { - Self { coeffs } - } -} - -/// Common functionality for finite field chips -pub trait FieldChip: Clone + Debug + Send + Sync { +/// Trait for common functionality for finite field chips. +/// Primarily intended to emulate a "non-native" finite field using "native" values in a prime field `F`. +/// Most functions are designed for the case when the non-native field is larger than the native field, but +/// the trait can still be implemented and used in other cases. +pub trait FieldChip: Clone + Send + Sync { const PRIME_FIELD_NUM_BITS: u32; - type ConstantType: Debug; - type WitnessType: Debug; - type FieldPoint: Clone + Debug + Send + Sync; - // a type implementing `Field` trait to help with witness generation (for example with inverse) + /// A representation of a field element that is used for intermediate computations. + /// The representation can have "overflows" (e.g., overflow limbs or negative limbs). + type UnsafeFieldPoint: Clone + + Debug + + Send + + Sync + + From + + for<'a> From<&'a Self::UnsafeFieldPoint> + + for<'a> From<&'a Self::FieldPoint>; // Cloning all the time impacts readability, so we allow references to be cloned into owned values + + /// The "proper" representation of a field element. Allowed to be a non-unique representation of a field element (e.g., can be greater than modulus) + type FieldPoint: Clone + + Debug + + Send + + Sync + + From + + for<'a> From<&'a Self::FieldPoint>; + + /// A proper representation of field elements that guarantees a unique representation of each field element. Typically this means Uints that are less than the modulus. + type ReducedFieldPoint: Clone + Debug + Send + Sync; + + /// A type implementing `Field` trait to help with witness generation (for example with inverse) type FieldType: Field; type RangeChip: RangeInstructions; @@ -52,81 +57,124 @@ pub trait FieldChip: Clone + Debug + Send + Sync { fn range(&self) -> &Self::RangeChip; fn limb_bits(&self) -> usize; - fn get_assigned_value(&self, x: &Self::FieldPoint) -> Self::FieldType; + fn get_assigned_value(&self, x: &Self::UnsafeFieldPoint) -> Self::FieldType; - fn fe_to_constant(x: Self::FieldType) -> Self::ConstantType; - fn fe_to_witness(x: &Self::FieldType) -> Self::WitnessType; + /// Assigns `fe` as private witness. Note that the witness may **not** be constrained to be a unique representation of the field element `fe`. + fn load_private(&self, ctx: &mut Context, fe: Self::FieldType) -> Self::FieldPoint; - fn load_private(&self, ctx: &mut Context, coeffs: Self::WitnessType) -> Self::FieldPoint; + /// Assigns `fe` as private witness and contrains the witness to be in reduced form. + fn load_private_reduced( + &self, + ctx: &mut Context, + fe: Self::FieldType, + ) -> Self::ReducedFieldPoint { + let fe = self.load_private(ctx, fe); + self.enforce_less_than(ctx, fe) + } - fn load_constant(&self, ctx: &mut Context, coeffs: Self::ConstantType) -> Self::FieldPoint; + /// Assigns `fe` as constant. + fn load_constant(&self, ctx: &mut Context, fe: Self::FieldType) -> Self::FieldPoint; fn add_no_carry( &self, ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, - ) -> Self::FieldPoint; + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint; /// output: `a + c` fn add_constant_no_carry( &self, ctx: &mut Context, - a: &Self::FieldPoint, - c: Self::ConstantType, - ) -> Self::FieldPoint; + a: impl Into, + c: Self::FieldType, + ) -> Self::UnsafeFieldPoint; fn sub_no_carry( &self, ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, - ) -> Self::FieldPoint; + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint; - fn negate(&self, ctx: &mut Context, a: &Self::FieldPoint) -> Self::FieldPoint; + fn negate(&self, ctx: &mut Context, a: Self::FieldPoint) -> Self::FieldPoint; /// a * c fn scalar_mul_no_carry( &self, ctx: &mut Context, - a: &Self::FieldPoint, + a: impl Into, c: i64, - ) -> Self::FieldPoint; + ) -> Self::UnsafeFieldPoint; /// a * c + b fn scalar_mul_and_add_no_carry( &self, ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, + a: impl Into, + b: impl Into, c: i64, - ) -> Self::FieldPoint; + ) -> Self::UnsafeFieldPoint; fn mul_no_carry( &self, ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, - ) -> Self::FieldPoint; + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint; - fn check_carry_mod_to_zero(&self, ctx: &mut Context, a: &Self::FieldPoint); + fn check_carry_mod_to_zero(&self, ctx: &mut Context, a: Self::UnsafeFieldPoint); - fn carry_mod(&self, ctx: &mut Context, a: &Self::FieldPoint) -> Self::FieldPoint; + fn carry_mod(&self, ctx: &mut Context, a: Self::UnsafeFieldPoint) -> Self::FieldPoint; - fn range_check(&self, ctx: &mut Context, a: &Self::FieldPoint, max_bits: usize); + fn range_check( + &self, + ctx: &mut Context, + a: impl Into, + max_bits: usize, + ); - fn enforce_less_than(&self, ctx: &mut Context, a: &Self::FieldPoint); + /// Constrains that `a` is a reduced representation and returns the wrapped `a`. + fn enforce_less_than( + &self, + ctx: &mut Context, + a: Self::FieldPoint, + ) -> Self::ReducedFieldPoint; // Returns 1 iff the underlying big integer for `a` is 0. Otherwise returns 0. // For field extensions, checks coordinate-wise. - fn is_soft_zero(&self, ctx: &mut Context, a: &Self::FieldPoint) -> AssignedValue; + fn is_soft_zero( + &self, + ctx: &mut Context, + a: impl Into, + ) -> AssignedValue; // Constrains that the underlying big integer is in [0, p - 1]. // Then returns 1 iff the underlying big integer for `a` is 0. Otherwise returns 0. // For field extensions, checks coordinate-wise. - fn is_soft_nonzero(&self, ctx: &mut Context, a: &Self::FieldPoint) -> AssignedValue; + fn is_soft_nonzero( + &self, + ctx: &mut Context, + a: impl Into, + ) -> AssignedValue; - fn is_zero(&self, ctx: &mut Context, a: &Self::FieldPoint) -> AssignedValue; + fn is_zero(&self, ctx: &mut Context, a: impl Into) -> AssignedValue; + + fn is_equal_unenforced( + &self, + ctx: &mut Context, + a: Self::ReducedFieldPoint, + b: Self::ReducedFieldPoint, + ) -> AssignedValue; + + fn assert_equal( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ); + + // =========== default implementations ============= // assuming `a, b` have been range checked to be a proper BigInt // constrain the witnesses `a, b` to be `< p` @@ -134,97 +182,117 @@ pub trait FieldChip: Clone + Debug + Send + Sync { fn is_equal( &self, ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, + a: impl Into, + b: impl Into, ) -> AssignedValue { - self.enforce_less_than(ctx, a); - self.enforce_less_than(ctx, b); + let a = self.enforce_less_than(ctx, a.into()); + let b = self.enforce_less_than(ctx, b.into()); // a.native and b.native are derived from `a.truncation, b.truncation`, so no need to check if they're equal self.is_equal_unenforced(ctx, a, b) } - fn is_equal_unenforced( - &self, - ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, - ) -> AssignedValue; - - fn assert_equal(&self, ctx: &mut Context, a: &Self::FieldPoint, b: &Self::FieldPoint); - + /// If using `UnsafeFieldPoint`, make sure multiplication does not cause overflow. fn mul( &self, ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, + a: impl Into, + b: impl Into, ) -> Self::FieldPoint { let no_carry = self.mul_no_carry(ctx, a, b); - self.carry_mod(ctx, &no_carry) + self.carry_mod(ctx, no_carry) } + /// Constrains that `b` is nonzero as a field element and then returns `a / b`. fn divide( &self, ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, + a: impl Into, + b: impl Into, ) -> Self::FieldPoint { - let a_val = self.get_assigned_value(a); - let b_val = self.get_assigned_value(b); - let b_inv = b_val.invert().unwrap(); + let b = b.into(); + let b_is_zero = self.is_zero(ctx, b.clone()); + self.gate().assert_is_const(ctx, &b_is_zero, &F::zero()); + + self.divide_unsafe(ctx, a.into(), b) + } + + /// Returns `a / b` without constraining `b` to be nonzero. + /// + /// Warning: undefined behavior when `b` is zero. + /// + /// `a, b` must be such that `quot * b - a` without carry does not overflow, where `quot` is the output. + fn divide_unsafe( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::FieldPoint { + let a = a.into(); + let b = b.into(); + let a_val = self.get_assigned_value(&a); + let b_val = self.get_assigned_value(&b); + let b_inv: Self::FieldType = Option::from(b_val.invert()).unwrap_or_default(); let quot_val = a_val * b_inv; - let quot = self.load_private(ctx, Self::fe_to_witness("_val)); + let quot = self.load_private(ctx, quot_val); // constrain quot * b - a = 0 mod p - let quot_b = self.mul_no_carry(ctx, ", b); - let quot_constraint = self.sub_no_carry(ctx, "_b, a); - self.check_carry_mod_to_zero(ctx, "_constraint); + let quot_b = self.mul_no_carry(ctx, quot.clone(), b); + let quot_constraint = self.sub_no_carry(ctx, quot_b, a); + self.check_carry_mod_to_zero(ctx, quot_constraint); quot } - // constrain and output -a / b - // this is usually cheaper constraint-wise than computing -a and then (-a) / b separately + /// Constrains that `b` is nonzero as a field element and then returns `-a / b`. fn neg_divide( &self, ctx: &mut Context, - a: &Self::FieldPoint, - b: &Self::FieldPoint, + a: impl Into, + b: impl Into, + ) -> Self::FieldPoint { + let b = b.into(); + let b_is_zero = self.is_zero(ctx, b.clone()); + self.gate().assert_is_const(ctx, &b_is_zero, &F::zero()); + + self.neg_divide_unsafe(ctx, a.into(), b) + } + + // Returns `-a / b` without constraining `b` to be nonzero. + // this is usually cheaper constraint-wise than computing -a and then (-a) / b separately + fn neg_divide_unsafe( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, ) -> Self::FieldPoint { - let a_val = self.get_assigned_value(a); - let b_val = self.get_assigned_value(b); - let b_inv = b_val.invert().unwrap(); + let a = a.into(); + let b = b.into(); + let a_val = self.get_assigned_value(&a); + let b_val = self.get_assigned_value(&b); + let b_inv: Self::FieldType = Option::from(b_val.invert()).unwrap_or_default(); let quot_val = -a_val * b_inv; - let quot = self.load_private(ctx, Self::fe_to_witness("_val)); - self.range_check(ctx, ", Self::PRIME_FIELD_NUM_BITS as usize); + let quot = self.load_private(ctx, quot_val); // constrain quot * b + a = 0 mod p - let quot_b = self.mul_no_carry(ctx, ", b); - let quot_constraint = self.add_no_carry(ctx, "_b, a); - self.check_carry_mod_to_zero(ctx, "_constraint); + let quot_b = self.mul_no_carry(ctx, quot.clone(), b); + let quot_constraint = self.add_no_carry(ctx, quot_b, a); + self.check_carry_mod_to_zero(ctx, quot_constraint); quot } } -pub trait Selectable { - type Point; - - fn select( - &self, - ctx: &mut Context, - a: &Self::Point, - b: &Self::Point, - sel: AssignedValue, - ) -> Self::Point; +pub trait Selectable { + fn select(&self, ctx: &mut Context, a: Pt, b: Pt, sel: AssignedValue) -> Pt; fn select_by_indicator( &self, ctx: &mut Context, - a: &[Self::Point], + a: &impl AsRef<[Pt]>, coeffs: &[AssignedValue], - ) -> Self::Point; + ) -> Pt; } // Common functionality for prime field chips diff --git a/halo2-ecc/src/fields/tests/fp/assert_eq.rs b/halo2-ecc/src/fields/tests/fp/assert_eq.rs index ef45aa99..5aac74bf 100644 --- a/halo2-ecc/src/fields/tests/fp/assert_eq.rs +++ b/halo2-ecc/src/fields/tests/fp/assert_eq.rs @@ -8,13 +8,10 @@ use halo2_base::{ RangeChip, }, halo2_proofs::{ - halo2curves::bn256::{Fq, Fr}, - plonk::keygen_pk, - plonk::keygen_vk, + halo2curves::bn256::Fq, plonk::keygen_pk, plonk::keygen_vk, poly::kzg::commitment::ParamsKZG, }, }; -use num_bigint::BigInt; use crate::{bn254::FpChip, fields::FieldChip}; use rand::thread_rng; @@ -30,8 +27,8 @@ fn test_fp_assert_eq_gen(k: u32, lookup_bits: usize, num_tries: usize) { let chip = FpChip::new(&range, 88, 3); let ctx = builder.main(0); - let a = chip.load_private(ctx, BigInt::from(0)); - let b = chip.load_private(ctx, BigInt::from(0)); + let a = chip.load_private(ctx, Fq::zero()); + let b = chip.load_private(ctx, Fq::zero()); chip.assert_equal(ctx, &a, &b); // set env vars builder.config(k as usize, Some(9)); @@ -51,7 +48,7 @@ fn test_fp_assert_eq_gen(k: u32, lookup_bits: usize, num_tries: usize) { let chip = FpChip::new(&range, 88, 3); let ctx = builder.main(0); - let [a, b] = [a, b].map(|x| chip.load_private(ctx, FpChip::::fe_to_witness(&x))); + let [a, b] = [a, b].map(|x| chip.load_private(ctx, x)); chip.assert_equal(ctx, &a, &b); let circuit = RangeCircuitBuilder::prover(builder, vec![vec![]]); // no break points gen_proof(¶ms, &pk, circuit) diff --git a/halo2-ecc/src/fields/tests/fp/mod.rs b/halo2-ecc/src/fields/tests/fp/mod.rs index a4dfb24e..9489abb5 100644 --- a/halo2-ecc/src/fields/tests/fp/mod.rs +++ b/halo2-ecc/src/fields/tests/fp/mod.rs @@ -27,15 +27,12 @@ fn fp_mul_test( let range = RangeChip::::default(lookup_bits); let chip = FpChip::::new(&range, limb_bits, num_limbs); - let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, FpChip::::fe_to_witness(&x))); - let c = chip.mul(ctx, &a, &b); + let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); + let c = chip.mul(ctx, a, b); - assert_eq!(c.truncation.to_bigint(limb_bits), c.value); - assert_eq!( - c.native.value(), - &biguint_to_fe(&(&c.value.to_biguint().unwrap() % modulus::())) - ); - assert_eq!(c.value, fe_to_biguint(&(_a * _b)).into()) + assert_eq!(c.0.truncation.to_bigint(limb_bits), c.0.value); + assert_eq!(c.native().value(), &biguint_to_fe(&(c.value() % modulus::()))); + assert_eq!(c.0.value, fe_to_biguint(&(_a * _b)).into()) } #[test] diff --git a/halo2-ecc/src/fields/tests/fp12/mod.rs b/halo2-ecc/src/fields/tests/fp12/mod.rs index 7112d690..6fb631b9 100644 --- a/halo2-ecc/src/fields/tests/fp12/mod.rs +++ b/halo2-ecc/src/fields/tests/fp12/mod.rs @@ -25,13 +25,11 @@ fn fp12_mul_test( let fp_chip = FpChip::::new(&range, limb_bits, num_limbs); let chip = Fp12Chip::::new(&fp_chip); - let [a, b] = [_a, _b].map(|x| { - chip.load_private(ctx, Fp12Chip::, Fq12, XI_0>::fe_to_witness(&x)) - }); - let c = chip.mul(ctx, &a, &b); + let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); + let c = chip.mul(ctx, a, b).into(); assert_eq!(chip.get_assigned_value(&c), _a * _b); - for c in c.coeffs { + for c in c.into_iter() { assert_eq!(c.truncation.to_bigint(limb_bits), c.value); } } diff --git a/halo2-ecc/src/fields/vector.rs b/halo2-ecc/src/fields/vector.rs new file mode 100644 index 00000000..6aea9d97 --- /dev/null +++ b/halo2-ecc/src/fields/vector.rs @@ -0,0 +1,495 @@ +use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; +use itertools::Itertools; +use std::{ + marker::PhantomData, + ops::{Index, IndexMut}, +}; + +use crate::bigint::{CRTInteger, ProperCrtUint}; + +use super::{fp::Reduced, FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, Selectable}; + +/// A fixed length vector of `FieldPoint`s +#[repr(transparent)] +#[derive(Clone, Debug)] +pub struct FieldVector(pub Vec); + +impl Index for FieldVector { + type Output = T; + + fn index(&self, index: usize) -> &Self::Output { + &self.0[index] + } +} + +impl IndexMut for FieldVector { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.0[index] + } +} + +impl AsRef<[T]> for FieldVector { + fn as_ref(&self) -> &[T] { + &self.0 + } +} + +impl<'a, T: Clone, U: From> From<&'a FieldVector> for FieldVector { + fn from(other: &'a FieldVector) -> Self { + FieldVector(other.clone().into_iter().map(Into::into).collect()) + } +} + +impl From>> for FieldVector> { + fn from(other: FieldVector>) -> Self { + FieldVector(other.into_iter().map(|x| x.0).collect()) + } +} + +impl From>> for FieldVector { + fn from(value: FieldVector>) -> Self { + FieldVector(value.0.into_iter().map(|x| x.0).collect()) + } +} + +impl IntoIterator for FieldVector { + type Item = T; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +/// Contains common functionality for vector operations that can be derived from those of the underlying `FpChip` +#[derive(Clone, Copy, Debug)] +pub struct FieldVectorChip<'fp, F: PrimeField, FpChip: FieldChip> { + pub fp_chip: &'fp FpChip, + _f: PhantomData, +} + +impl<'fp, F, FpChip> FieldVectorChip<'fp, F, FpChip> +where + F: PrimeField, + FpChip: PrimeFieldChip, + FpChip::FieldType: PrimeField, +{ + pub fn new(fp_chip: &'fp FpChip) -> Self { + Self { fp_chip, _f: PhantomData } + } + + pub fn gate(&self) -> &impl GateInstructions { + self.fp_chip.gate() + } + + pub fn fp_mul_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + fp_point: impl Into, + ) -> FieldVector + where + FP: Into, + { + let fp_point = fp_point.into(); + FieldVector( + a.into_iter().map(|a| self.fp_chip.mul_no_carry(ctx, a, fp_point.clone())).collect(), + ) + } + + pub fn select( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + sel: AssignedValue, + ) -> FieldVector + where + FpChip: Selectable, + { + FieldVector( + a.into_iter().zip_eq(b).map(|(a, b)| self.fp_chip.select(ctx, a, b, sel)).collect(), + ) + } + + pub fn load_private( + &self, + ctx: &mut Context, + fe: FieldExt, + ) -> FieldVector + where + FieldExt: FieldExtConstructor, + { + FieldVector(fe.coeffs().into_iter().map(|a| self.fp_chip.load_private(ctx, a)).collect()) + } + + pub fn load_constant( + &self, + ctx: &mut Context, + c: FieldExt, + ) -> FieldVector + where + FieldExt: FieldExtConstructor, + { + FieldVector(c.coeffs().into_iter().map(|a| self.fp_chip.load_constant(ctx, a)).collect()) + } + + // signed overflow BigInt functions + pub fn add_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + ) -> FieldVector + where + A: Into, + B: Into, + { + FieldVector( + a.into_iter().zip_eq(b).map(|(a, b)| self.fp_chip.add_no_carry(ctx, a, b)).collect(), + ) + } + + pub fn add_constant_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + c: FieldExt, + ) -> FieldVector + where + A: Into, + FieldExt: FieldExtConstructor, + { + let c_coeffs = c.coeffs(); + FieldVector( + a.into_iter() + .zip_eq(c_coeffs) + .map(|(a, c)| self.fp_chip.add_constant_no_carry(ctx, a, c)) + .collect(), + ) + } + + pub fn sub_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + ) -> FieldVector + where + A: Into, + B: Into, + { + FieldVector( + a.into_iter().zip_eq(b).map(|(a, b)| self.fp_chip.sub_no_carry(ctx, a, b)).collect(), + ) + } + + pub fn negate( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> FieldVector { + FieldVector(a.into_iter().map(|a| self.fp_chip.negate(ctx, a)).collect()) + } + + pub fn scalar_mul_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + c: i64, + ) -> FieldVector + where + A: Into, + { + FieldVector(a.into_iter().map(|a| self.fp_chip.scalar_mul_no_carry(ctx, a, c)).collect()) + } + + pub fn scalar_mul_and_add_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + c: i64, + ) -> FieldVector + where + A: Into, + B: Into, + { + FieldVector( + a.into_iter() + .zip_eq(b) + .map(|(a, b)| self.fp_chip.scalar_mul_and_add_no_carry(ctx, a, b, c)) + .collect(), + ) + } + + pub fn check_carry_mod_to_zero( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) { + for coeff in a { + self.fp_chip.check_carry_mod_to_zero(ctx, coeff); + } + } + + pub fn carry_mod( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> FieldVector { + FieldVector(a.into_iter().map(|coeff| self.fp_chip.carry_mod(ctx, coeff)).collect()) + } + + pub fn range_check( + &self, + ctx: &mut Context, + a: impl IntoIterator, + max_bits: usize, + ) where + A: Into, + { + for coeff in a { + self.fp_chip.range_check(ctx, coeff, max_bits); + } + } + + pub fn enforce_less_than( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> FieldVector { + FieldVector(a.into_iter().map(|coeff| self.fp_chip.enforce_less_than(ctx, coeff)).collect()) + } + + pub fn is_soft_zero( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> AssignedValue { + let mut prev = None; + for a_coeff in a { + let coeff = self.fp_chip.is_soft_zero(ctx, a_coeff); + if let Some(p) = prev { + let new = self.gate().and(ctx, coeff, p); + prev = Some(new); + } else { + prev = Some(coeff); + } + } + prev.unwrap() + } + + pub fn is_soft_nonzero( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> AssignedValue { + let mut prev = None; + for a_coeff in a { + let coeff = self.fp_chip.is_soft_nonzero(ctx, a_coeff); + if let Some(p) = prev { + let new = self.gate().or(ctx, coeff, p); + prev = Some(new); + } else { + prev = Some(coeff); + } + } + prev.unwrap() + } + + pub fn is_zero( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> AssignedValue { + let mut prev = None; + for a_coeff in a { + let coeff = self.fp_chip.is_zero(ctx, a_coeff); + if let Some(p) = prev { + let new = self.gate().and(ctx, coeff, p); + prev = Some(new); + } else { + prev = Some(coeff); + } + } + prev.unwrap() + } + + pub fn is_equal_unenforced( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + ) -> AssignedValue { + let mut acc = None; + for (a_coeff, b_coeff) in a.into_iter().zip_eq(b) { + let coeff = self.fp_chip.is_equal_unenforced(ctx, a_coeff, b_coeff); + if let Some(c) = acc { + acc = Some(self.gate().and(ctx, coeff, c)); + } else { + acc = Some(coeff); + } + } + acc.unwrap() + } + + pub fn assert_equal( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + ) { + for (a_coeff, b_coeff) in a.into_iter().zip(b) { + self.fp_chip.assert_equal(ctx, a_coeff, b_coeff) + } + } +} + +#[macro_export] +macro_rules! impl_field_ext_chip_common { + // Implementation of the functions in `FieldChip` trait for field extensions that can be derived from `FieldVectorChip` + () => { + fn native_modulus(&self) -> &BigUint { + self.0.fp_chip.native_modulus() + } + + fn range(&self) -> &Self::RangeChip { + self.0.fp_chip.range() + } + + fn limb_bits(&self) -> usize { + self.0.fp_chip.limb_bits() + } + + fn load_private(&self, ctx: &mut Context, fe: Self::FieldType) -> Self::FieldPoint { + self.0.load_private(ctx, fe) + } + + fn load_constant(&self, ctx: &mut Context, fe: Self::FieldType) -> Self::FieldPoint { + self.0.load_constant(ctx, fe) + } + + fn add_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint { + self.0.add_no_carry(ctx, a.into(), b.into()) + } + + fn add_constant_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + c: Self::FieldType, + ) -> Self::UnsafeFieldPoint { + self.0.add_constant_no_carry(ctx, a.into(), c) + } + + fn sub_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint { + self.0.sub_no_carry(ctx, a.into(), b.into()) + } + + fn negate(&self, ctx: &mut Context, a: Self::FieldPoint) -> Self::FieldPoint { + self.0.negate(ctx, a) + } + + fn scalar_mul_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + c: i64, + ) -> Self::UnsafeFieldPoint { + self.0.scalar_mul_no_carry(ctx, a.into(), c) + } + + fn scalar_mul_and_add_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + c: i64, + ) -> Self::UnsafeFieldPoint { + self.0.scalar_mul_and_add_no_carry(ctx, a.into(), b.into(), c) + } + + fn check_carry_mod_to_zero(&self, ctx: &mut Context, a: Self::UnsafeFieldPoint) { + self.0.check_carry_mod_to_zero(ctx, a); + } + + fn carry_mod(&self, ctx: &mut Context, a: Self::UnsafeFieldPoint) -> Self::FieldPoint { + self.0.carry_mod(ctx, a) + } + + fn range_check( + &self, + ctx: &mut Context, + a: impl Into, + max_bits: usize, + ) { + self.0.range_check(ctx, a.into(), max_bits) + } + + fn enforce_less_than( + &self, + ctx: &mut Context, + a: Self::FieldPoint, + ) -> Self::ReducedFieldPoint { + self.0.enforce_less_than(ctx, a) + } + + fn is_soft_zero( + &self, + ctx: &mut Context, + a: impl Into, + ) -> AssignedValue { + let a = a.into(); + self.0.is_soft_zero(ctx, a) + } + + fn is_soft_nonzero( + &self, + ctx: &mut Context, + a: impl Into, + ) -> AssignedValue { + let a = a.into(); + self.0.is_soft_nonzero(ctx, a) + } + + fn is_zero( + &self, + ctx: &mut Context, + a: impl Into, + ) -> AssignedValue { + let a = a.into(); + self.0.is_zero(ctx, a) + } + + fn is_equal_unenforced( + &self, + ctx: &mut Context, + a: Self::ReducedFieldPoint, + b: Self::ReducedFieldPoint, + ) -> AssignedValue { + self.0.is_equal_unenforced(ctx, a, b) + } + + fn assert_equal( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) { + let a = a.into(); + let b = b.into(); + self.0.assert_equal(ctx, a, b) + } + }; +} diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa.rs b/halo2-ecc/src/secp256k1/tests/ecdsa.rs index 739bffc7..af7050f9 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa.rs @@ -62,14 +62,13 @@ fn ecdsa_test( let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); - let [m, r, s] = - [msghash, r, s].map(|x| fq_chip.load_private(ctx, FqChip::::fe_to_witness(&x))); + let [m, r, s] = [msghash, r, s].map(|x| fq_chip.load_private(ctx, x)); let ecc_chip = EccChip::>::new(&fp_chip); - let pk = ecc_chip.load_private(ctx, (pk.x, pk.y)); + let pk = ecc_chip.load_private_unchecked(ctx, (pk.x, pk.y)); // test ECDSA let res = ecdsa_verify_no_pubkey_check::( - &fp_chip, ctx, &pk, &r, &s, &m, 4, 4, + &ecc_chip, ctx, pk, r, s, m, 4, 4, ); assert_eq!(res.value(), &F::one()); } diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs index 27565d6a..27d4c1c6 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs @@ -19,10 +19,11 @@ use halo2_base::gates::builder::{ use halo2_base::gates::RangeChip; use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus}; use halo2_base::Context; +use rand::random; use rand_core::OsRng; use serde::{Deserialize, Serialize}; use std::fs::File; - +use test_case::test_case; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] struct CircuitParams { @@ -49,25 +50,18 @@ fn ecdsa_test( let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); - let [m, r, s] = - [msghash, r, s].map(|x| fq_chip.load_private(ctx, FqChip::::fe_to_witness(&x))); + let [m, r, s] = [msghash, r, s].map(|x| fq_chip.load_private(ctx, x)); let ecc_chip = EccChip::>::new(&fp_chip); - let pk = ecc_chip.load_private(ctx, (pk.x, pk.y)); + let pk = ecc_chip.assign_point(ctx, pk); // test ECDSA let res = ecdsa_verify_no_pubkey_check::( - &fp_chip, ctx, &pk, &r, &s, &m, 4, 4, + &ecc_chip, ctx, pk, r, s, m, 4, 4, ); assert_eq!(res.value(), &F::one()); } - -fn random_parameters_ecdsa() -> ( - Fq, - Fq, - Fq, - Secp256k1Affine, - ) { +fn random_parameters_ecdsa() -> (Fq, Fq, Fq, Secp256k1Affine) { let sk = ::ScalarExt::random(OsRng); let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); let msg_hash = ::ScalarExt::random(OsRng); @@ -78,24 +72,14 @@ fn random_parameters_ecdsa() -> ( let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); let x = r_point.x(); let x_bigint = fe_to_biguint(x); - - + let r = biguint_to_fe::(&(x_bigint % modulus::())); let s = k_inv * (msg_hash + (r * sk)); - - return ( r, s, msg_hash, pubkey) + + (r, s, msg_hash, pubkey) } -fn custom_parameters_ecdsa( - sk: u64, - msg_hash: u64, - k: u64, -) -> ( - Fq, - Fq, - Fq, - Secp256k1Affine, - ){ +fn custom_parameters_ecdsa(sk: u64, msg_hash: u64, k: u64) -> (Fq, Fq, Fq, Secp256k1Affine) { let sk = ::ScalarExt::from(sk); let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); let msg_hash = ::ScalarExt::from(msg_hash); @@ -106,21 +90,22 @@ fn custom_parameters_ecdsa( let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); let x = r_point.x(); let x_bigint = fe_to_biguint(x); - - + let r = biguint_to_fe::(&(x_bigint % modulus::())); let s = k_inv * (msg_hash + (r * sk)); - - return (r, s, msg_hash, pubkey) -} + (r, s, msg_hash, pubkey) +} -fn ecdsa_circuit(r: Fq, s: Fq, msg_hash: Fq, pubkey: Secp256k1Affine, +fn ecdsa_circuit( + r: Fq, + s: Fq, + msg_hash: Fq, + pubkey: Secp256k1Affine, params: CircuitParams, stage: CircuitBuilderStage, break_points: Option, ) -> RangeCircuitBuilder { - let mut builder = match stage { CircuitBuilderStage::Mock => GateThreadBuilder::mock(), CircuitBuilderStage::Prover => GateThreadBuilder::prover(), @@ -128,7 +113,6 @@ fn ecdsa_circuit(r: Fq, s: Fq, msg_hash: Fq, pubkey: Secp256k1Affine, }; let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); - let circuit = match stage { CircuitBuilderStage::Mock => { @@ -145,69 +129,52 @@ fn ecdsa_circuit(r: Fq, s: Fq, msg_hash: Fq, pubkey: Secp256k1Affine, circuit } +#[test] +#[should_panic(expected = "assertion failed: `(left == right)`")] +fn test_ecdsa_msg_hash_zero() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(random::(), 0, random::()); + + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} -use rand::random; -#[cfg(test)] - #[test] - #[should_panic(expected = "assertion failed: `(left == right)`")] - fn test_ecdsa_msg_hash_zero() - { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(random::(),0, random::()); - - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None,); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); - } - - #[cfg(test)] - #[test] - #[should_panic(expected = "assertion failed: `(left == right)`")] - fn test_ecdsa_private_key_zero() - { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(0, random::(), random::()); - - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None,); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); - } - - - -#[cfg(test)] #[test] -fn test_ecdsa_random_valid_inputs() - { +#[should_panic(expected = "assertion failed: `(left == right)`")] +fn test_ecdsa_private_key_zero() { let path = "configs/secp256k1/ecdsa_circuit.config"; let params: CircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - let (r, s, msg_hash, pubkey) = random_parameters_ecdsa(); + let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(0, random::(), random::()); - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None,); + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } +#[test] +fn test_ecdsa_random_valid_inputs() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + let (r, s, msg_hash, pubkey) = random_parameters_ecdsa(); + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} -use test_case::test_case; -#[cfg(test)] #[test_case(1, 1, 1; "")] -fn test_ecdsa_custom_valid_inputs(sk: u64,msg_hash: u64, k: u64,) - { +fn test_ecdsa_custom_valid_inputs(sk: u64, msg_hash: u64, k: u64) { let path = "configs/secp256k1/ecdsa_circuit.config"; let params: CircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), @@ -216,16 +183,12 @@ fn test_ecdsa_custom_valid_inputs(sk: u64,msg_hash: u64, k: u64,) let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(sk, msg_hash, k); - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None,); + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } - - -#[cfg(test)] #[test_case(1, 1, 1; "")] -fn test_ecdsa_custom_valid_inputs_negative_s(sk: u64,msg_hash: u64, k: u64,) - { +fn test_ecdsa_custom_valid_inputs_negative_s(sk: u64, msg_hash: u64, k: u64) { let path = "configs/secp256k1/ecdsa_circuit.config"; let params: CircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), @@ -235,6 +198,6 @@ fn test_ecdsa_custom_valid_inputs_negative_s(sk: u64,msg_hash: u64, k: u64,) let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(sk, msg_hash, k); let s = -s; - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None,); + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi.rs b/hashes/zkevm-keccak/src/keccak_packed_multi.rs index 3edc2e1a..55be8306 100644 --- a/hashes/zkevm-keccak/src/keccak_packed_multi.rs +++ b/hashes/zkevm-keccak/src/keccak_packed_multi.rs @@ -285,6 +285,7 @@ impl CellManager { let column = if column_idx < self.columns.len() { self.columns[column_idx].advice } else { + assert!(column_idx == self.columns.len()); let advice = meta.advice_column(); let mut expr = 0.expr(); meta.create_gate("Query column", |meta| { @@ -337,7 +338,7 @@ impl CellManager { // Make sure all rows start at the same column let width = self.get_width(); #[cfg(debug_assertions)] - for row in self.rows.iter_mut() { + for row in self.rows.iter() { self.num_unused_cells += width - *row; } self.rows = vec![width; self.height]; @@ -1135,7 +1136,7 @@ impl KeccakCircuitConfig { for i in 0..5 { let input = scatter::expr(3, part_size_base) - 2.expr() * input[i].clone() + input[(i + 1) % 5].clone() - - input[(i + 2) % 5].clone().clone(); + - input[(i + 2) % 5].clone(); let output = output[i].clone(); meta.lookup("chi base", |_| { vec![(input.clone(), chi_base_table[0]), (output.clone(), chi_base_table[1])] @@ -1941,7 +1942,7 @@ pub fn keccak_phase0( .take(4) .map(|a| { pack_with_base::(&unpack(a[0]), 2) - .to_repr() + .to_bytes_le() .into_iter() .take(8) .collect::>() diff --git a/hashes/zkevm-keccak/src/util.rs b/hashes/zkevm-keccak/src/util.rs index 868c366c..b3e2e2b5 100644 --- a/hashes/zkevm-keccak/src/util.rs +++ b/hashes/zkevm-keccak/src/util.rs @@ -183,7 +183,7 @@ pub fn pack_part(bits: &[u8], info: &PartInfo) -> u64 { /// Unpack a sparse keccak word into bits in the range [0,BIT_SIZE[ pub fn unpack(packed: F) -> [u8; NUM_BITS_PER_WORD] { let mut bits = [0; NUM_BITS_PER_WORD]; - let packed = Word::from_little_endian(packed.to_repr().as_ref()); + let packed = Word::from_little_endian(packed.to_bytes_le().as_ref()); let mask = Word::from(BIT_SIZE - 1); for (idx, bit) in bits.iter_mut().enumerate() { *bit = ((packed >> (idx * BIT_COUNT)) & mask).as_u32() as u8; @@ -200,10 +200,10 @@ pub fn pack_u64(value: u64) -> F { /// Calculates a ^ b with a and b field elements pub fn field_xor(a: F, b: F) -> F { let mut bytes = [0u8; 32]; - for (idx, (a, b)) in a.to_repr().as_ref().iter().zip(b.to_repr().as_ref().iter()).enumerate() { - bytes[idx] = *a ^ *b; + for (idx, (a, b)) in a.to_bytes_le().into_iter().zip(b.to_bytes_le()).enumerate() { + bytes[idx] = a ^ b; } - F::from_repr(bytes).unwrap() + F::from_bytes_le(&bytes) } /// Returns the size (in bits) of each part size when splitting up a keccak word diff --git a/hashes/zkevm-keccak/src/util/constraint_builder.rs b/hashes/zkevm-keccak/src/util/constraint_builder.rs index 94f47c8c..bae9f4a4 100644 --- a/hashes/zkevm-keccak/src/util/constraint_builder.rs +++ b/hashes/zkevm-keccak/src/util/constraint_builder.rs @@ -53,7 +53,7 @@ impl BaseConstraintBuilder { pub(crate) fn validate_degree(&self, degree: usize, name: &'static str) { if self.max_degree > 0 { - debug_assert!( + assert!( degree <= self.max_degree, "Expression {} degree too high: {} > {}", name, diff --git a/hashes/zkevm-keccak/src/util/eth_types.rs b/hashes/zkevm-keccak/src/util/eth_types.rs index 3217f810..6fed74a5 100644 --- a/hashes/zkevm-keccak/src/util/eth_types.rs +++ b/hashes/zkevm-keccak/src/util/eth_types.rs @@ -71,7 +71,7 @@ impl ToScalar for U256 { fn to_scalar(&self) -> Option { let mut bytes = [0u8; 32]; self.to_little_endian(&mut bytes); - F::from_repr(bytes).into() + Some(F::from_bytes_le(&bytes)) } } @@ -113,7 +113,7 @@ impl ToScalar for Address { let mut bytes = [0u8; 32]; bytes[32 - Self::len_bytes()..].copy_from_slice(self.as_bytes()); bytes.reverse(); - F::from_repr(bytes).into() + Some(F::from_bytes_le(&bytes)) } } From 73429156f2b3eab529b9df87241ecb71f235d174 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 6 Jun 2023 12:33:49 -0500 Subject: [PATCH 11/12] Fix/fb msm zero (#77) * fix: fixed_base scalar multiply for [-1]P * feat: use `multi_scalar_multiply` instead of `scalar_multiply` * to reduce code maintanence / redundancy * fix: add back scalar_multiply using any_point * feat: remove flag from variable base `scalar_multiply` * feat: add scalar multiply tests for secp256k1 * fix: variable scalar_multiply last select * Fix/msm tests output identity (#75) * fixed base msm tests for output infinity * fixed base msm tests for output infinity --------- Co-authored-by: yulliakot * feat: add tests and update CI --------- Co-authored-by: yuliakot <93175658+yuliakot@users.noreply.github.com> Co-authored-by: yulliakot --- .github/workflows/ci.yml | 8 +- .../configs/bn254/bench_fixed_msm.t.config | 5 + halo2-ecc/configs/bn254/bench_msm.t.config | 5 + .../configs/bn254/bench_pairing.t.config | 5 + halo2-ecc/src/bn254/tests/fixed_base_msm.rs | 28 ++- halo2-ecc/src/bn254/tests/mod.rs | 44 +++-- halo2-ecc/src/bn254/tests/msm_sum_infinity.rs | 15 -- .../tests/msm_sum_infinity_fixed_base.rs | 183 ++++++++++++++++++ halo2-ecc/src/ecc/ecdsa.rs | 8 +- halo2-ecc/src/ecc/fixed_base.rs | 71 +++---- halo2-ecc/src/ecc/mod.rs | 58 +++--- halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs | 14 +- halo2-ecc/src/secp256k1/tests/mod.rs | 160 +++++++++++++++ 13 files changed, 473 insertions(+), 131 deletions(-) create mode 100644 halo2-ecc/configs/bn254/bench_fixed_msm.t.config create mode 100644 halo2-ecc/configs/bn254/bench_msm.t.config create mode 100644 halo2-ecc/configs/bn254/bench_pairing.t.config create mode 100644 halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4fedf24b..564ddb6f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,11 +27,12 @@ jobs: cd halo2-ecc cargo test -- --test-threads=1 test_fp cargo test -- test_ecc - cargo test -- test_secp256k1_ecdsa + cargo test -- test_secp cargo test -- test_ecdsa cargo test -- test_ec_add - cargo test -- test_fixed_base_msm + cargo test -- test_fixed cargo test -- test_msm + cargo test -- test_fb cargo test -- test_pairing cd .. - name: Run halo2-ecc tests real prover @@ -40,7 +41,10 @@ jobs: cargo test --release -- test_fp_assert_eq cargo test --release -- --nocapture bench_secp256k1_ecdsa cargo test --release -- --nocapture bench_ec_add + mv configs/bn254/bench_fixed_msm.t.config configs/bn254/bench_fixed_msm.config cargo test --release -- --nocapture bench_fixed_base_msm + mv configs/bn254/bench_msm.t.config configs/bn254/bench_msm.config cargo test --release -- --nocapture bench_msm + mv configs/bn254/bench_pairing.t.config configs/bn254/bench_pairing.config cargo test --release -- --nocapture bench_pairing cd .. diff --git a/halo2-ecc/configs/bn254/bench_fixed_msm.t.config b/halo2-ecc/configs/bn254/bench_fixed_msm.t.config new file mode 100644 index 00000000..61db5d6d --- /dev/null +++ b/halo2-ecc/configs/bn254/bench_fixed_msm.t.config @@ -0,0 +1,5 @@ +{"strategy":"Simple","degree":17,"num_advice":83,"num_lookup_advice":9,"num_fixed":7,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":18,"num_advice":42,"num_lookup_advice":5,"num_fixed":4,"lookup_bits":17,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":19,"num_advice":20,"num_lookup_advice":2,"num_fixed":2,"lookup_bits":18,"limb_bits":90,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":20,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3,"batch_size":50,"radix":0,"clump_factor":4} \ No newline at end of file diff --git a/halo2-ecc/configs/bn254/bench_msm.t.config b/halo2-ecc/configs/bn254/bench_msm.t.config new file mode 100644 index 00000000..bd4c4318 --- /dev/null +++ b/halo2-ecc/configs/bn254/bench_msm.t.config @@ -0,0 +1,5 @@ +{"strategy":"Simple","degree":16,"num_advice":170,"num_lookup_advice":23,"num_fixed":1,"lookup_bits":15,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} +{"strategy":"Simple","degree":17,"num_advice":84,"num_lookup_advice":11,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} +{"strategy":"Simple","degree":19,"num_advice":20,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":18,"limb_bits":90,"num_limbs":3,"batch_size":100,"window_bits":4} +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"window_bits":4} +{"strategy":"Simple","degree":20,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3,"batch_size":50,"window_bits":4} \ No newline at end of file diff --git a/halo2-ecc/configs/bn254/bench_pairing.t.config b/halo2-ecc/configs/bn254/bench_pairing.t.config new file mode 100644 index 00000000..d76ebad1 --- /dev/null +++ b/halo2-ecc/configs/bn254/bench_pairing.t.config @@ -0,0 +1,5 @@ +{"strategy":"Simple","degree":15,"num_advice":105,"num_lookup_advice":14,"num_fixed":1,"lookup_bits":14,"limb_bits":90,"num_limbs":3} +{"strategy":"Simple","degree":17,"num_advice":25,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3} +{"strategy":"Simple","degree":18,"num_advice":13,"num_lookup_advice":2,"num_fixed":1,"lookup_bits":17,"limb_bits":88,"num_limbs":3} +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":90,"num_limbs":3} +{"strategy":"Simple","degree":20,"num_advice":3,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs index a8f039c2..0283f672 100644 --- a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs +++ b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs @@ -23,7 +23,7 @@ use itertools::Itertools; use rand_core::OsRng; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct MSMCircuitParams { +struct FixedMSMCircuitParams { strategy: FpStrategy, degree: u32, num_advice: usize, @@ -39,7 +39,7 @@ struct MSMCircuitParams { fn fixed_base_msm_test( builder: &mut GateThreadBuilder, - params: MSMCircuitParams, + params: FixedMSMCircuitParams, bases: Vec, scalars: Vec, ) { @@ -68,7 +68,7 @@ fn fixed_base_msm_test( } fn random_fixed_base_msm_circuit( - params: MSMCircuitParams, + params: FixedMSMCircuitParams, bases: Vec, // bases are fixed in vkey so don't randomly generate stage: CircuitBuilderStage, break_points: Option, @@ -102,7 +102,7 @@ fn random_fixed_base_msm_circuit( #[test] fn test_fixed_base_msm() { let path = "configs/bn254/fixed_msm_circuit.config"; - let params: MSMCircuitParams = serde_json::from_reader( + let params: FixedMSMCircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); @@ -112,6 +112,23 @@ fn test_fixed_base_msm() { MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } +#[test] +fn test_fixed_msm_minus_1() { + let path = "configs/bn254/fixed_msm_circuit.config"; + let params: FixedMSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + let base = G1Affine::random(OsRng); + let k = params.degree as usize; + let mut builder = GateThreadBuilder::mock(); + fixed_base_msm_test(&mut builder, params, vec![base], vec![-Fr::one()]); + + builder.config(k, Some(20)); + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + #[test] fn bench_fixed_base_msm() -> Result<(), Box> { let config_path = "configs/bn254/bench_fixed_msm.config"; @@ -126,7 +143,8 @@ fn bench_fixed_base_msm() -> Result<(), Box> { let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { - let bench_params: MSMCircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); + let bench_params: FixedMSMCircuitParams = + serde_json::from_str(line.unwrap().as_str()).unwrap(); let k = bench_params.degree; println!("---------------------- degree = {k} ------------------------------",); let rng = OsRng; diff --git a/halo2-ecc/src/bn254/tests/mod.rs b/halo2-ecc/src/bn254/tests/mod.rs index c4ffb393..172300a1 100644 --- a/halo2-ecc/src/bn254/tests/mod.rs +++ b/halo2-ecc/src/bn254/tests/mod.rs @@ -1,20 +1,23 @@ #![allow(non_snake_case)] use super::pairing::PairingChip; use super::*; -use crate::halo2_proofs::{ - dev::MockProver, - halo2curves::bn256::{pairing, Bn256, Fr, G1Affine}, - plonk::*, - poly::commitment::ParamsProver, - poly::kzg::{ - commitment::KZGCommitmentScheme, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, +use crate::{ecc::EccChip, fields::PrimeField}; +use crate::{ + fields::FpStrategy, + halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{pairing, Bn256, Fr, G1Affine}, + plonk::*, + poly::commitment::ParamsProver, + poly::kzg::{ + commitment::KZGCommitmentScheme, + multiopen::{ProverSHPLONK, VerifierSHPLONK}, + strategy::SingleStrategy, + }, + transcript::{Blake2bRead, Blake2bWrite, Challenge255}, + transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, }, - transcript::{Blake2bRead, Blake2bWrite, Challenge255}, - transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, }; -use crate::{ecc::EccChip, fields::PrimeField}; use ark_std::{end_timer, start_timer}; use group::Curve; use halo2_base::utils::fe_to_biguint; @@ -24,5 +27,20 @@ use std::io::Write; pub mod ec_add; pub mod fixed_base_msm; pub mod msm; -pub mod pairing; pub mod msm_sum_infinity; +pub mod msm_sum_infinity_fixed_base; +pub mod pairing; + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +struct MSMCircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + batch_size: usize, + window_bits: usize, +} diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs index 052edea4..600a4931 100644 --- a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs @@ -1,4 +1,3 @@ -use crate::fields::FpStrategy; use ff::PrimeField; use halo2_base::gates::{ builder::{ @@ -11,20 +10,6 @@ use std::fs::File; use super::*; -#[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct MSMCircuitParams { - strategy: FpStrategy, - degree: u32, - num_advice: usize, - num_lookup_advice: usize, - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - batch_size: usize, - window_bits: usize, -} - fn msm_test( builder: &mut GateThreadBuilder, params: MSMCircuitParams, diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs new file mode 100644 index 00000000..6cf96c7f --- /dev/null +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs @@ -0,0 +1,183 @@ +use ff::PrimeField; +use halo2_base::gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + }, + RangeChip, +}; +use rand_core::OsRng; +use std::fs::File; + +use super::*; + +fn msm_test( + builder: &mut GateThreadBuilder, + params: MSMCircuitParams, + bases: Vec, + scalars: Vec, + window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let ctx = builder.main(0); + let scalars_assigned = + scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); + let bases_assigned = bases; + //.iter() + //.map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) + //.collect::>(); + + let msm = ecc_chip.fixed_base_msm_in::( + builder, + &bases_assigned, + scalars_assigned, + Fr::NUM_BITS as usize, + window_bits, + 0, + ); + + let msm_answer = bases_assigned + .iter() + .zip(scalars.iter()) + .map(|(base, scalar)| base * scalar) + .reduce(|a, b| a + b) + .unwrap() + .to_affine(); + + let msm_x = msm.x.value(); + let msm_y = msm.y.value(); + assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); + assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); +} + +fn custom_msm_circuit( + params: MSMCircuitParams, + stage: CircuitBuilderStage, + break_points: Option, + bases: Vec, + scalars: Vec, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + msm_test(&mut builder, params, bases, scalars, params.window_bits); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit +} + +#[test] +fn test_fb_msm1() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, random_point]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm2() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm3() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = vec![ + random_point, + random_point, + random_point, + (random_point + random_point + random_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm4() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let generator_point = G1Affine::generator(); + let bases = vec![ + generator_point, + generator_point, + generator_point, + (generator_point + generator_point + generator_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm5() { + // Very similar example that does not add to infinity. It works fine. + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = + vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} diff --git a/halo2-ecc/src/ecc/ecdsa.rs b/halo2-ecc/src/ecc/ecdsa.rs index d7406a17..ca0b111b 100644 --- a/halo2-ecc/src/ecc/ecdsa.rs +++ b/halo2-ecc/src/ecc/ecdsa.rs @@ -3,8 +3,7 @@ use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, use crate::bigint::{big_is_equal, big_less_than, FixedOverflowInteger, ProperCrtUint}; use crate::fields::{fp::FpChip, FieldChip, PrimeField}; -use super::{fixed_base, EccChip}; -use super::{scalar_multiply, EcPoint}; +use super::{fixed_base, scalar_multiply, EcPoint, EccChip}; // CF is the coordinate field of GA // SF is the scalar field of GA // p = coordinate field modulus @@ -12,6 +11,7 @@ use super::{scalar_multiply, EcPoint}; // Only valid when p is very close to n in size (e.g. for Secp256k1) // Assumes `r, s` are proper CRT integers /// **WARNING**: Only use this function if `1 / (p - n)` is very small (e.g., < 2-100) +/// `pubkey` should not be the identity point pub fn ecdsa_verify_no_pubkey_check( chip: &EccChip>, ctx: &mut Context, @@ -49,16 +49,14 @@ where u1.limbs().to_vec(), base_chip.limb_bits, fixed_window_bits, - true, // we can call it with scalar_is_safe = true because of the u1_small check below ); - let u2_mul = scalar_multiply( + let u2_mul = scalar_multiply::<_, _, GA>( base_chip, ctx, pubkey, u2.limbs().to_vec(), base_chip.limb_bits, var_window_bits, - true, // we can call it with scalar_is_safe = true because of the u2_small check below ); // check u1 * G != -(u2 * pubkey) but allow u1 * G == u2 * pubkey diff --git a/halo2-ecc/src/ecc/fixed_base.rs b/halo2-ecc/src/ecc/fixed_base.rs index dc67b8d6..5dfba754 100644 --- a/halo2-ecc/src/ecc/fixed_base.rs +++ b/halo2-ecc/src/ecc/fixed_base.rs @@ -1,6 +1,6 @@ #![allow(non_snake_case)] use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip}; -use crate::ecc::ec_sub_strict; +use crate::ecc::{ec_sub_strict, load_random_point}; use crate::fields::{FieldChip, PrimeField, Selectable}; use group::Curve; use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder}; @@ -17,8 +17,6 @@ use std::cmp::min; /// # Assumptions /// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits) /// - `scalar > 0` -/// - If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`) -/// - Even if `scalar_is_safe == false`, some constraints may still fail if `scalar` is not in range [1, order of `P`) /// - `max_bits <= modulus::.bits()` pub fn scalar_multiply( chip: &FC, @@ -27,7 +25,6 @@ pub fn scalar_multiply( scalar: Vec>, max_bits: usize, window_bits: usize, - scalar_is_safe: bool, ) -> EcPoint where F: PrimeField, @@ -87,29 +84,19 @@ where let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev(); let bit_window_rev = bits.chunks(window_bits).rev(); - let mut curr_point = None; - // `is_started` is just a way to deal with if `curr_point` is actually identity - let mut is_started = ctx.load_zero(); + let any_point = load_random_point::(chip, ctx); + let mut curr_point = any_point.clone(); for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) { let bit_sum = chip.gate().sum(ctx, bit_window.iter().copied()); // are we just adding a window of all 0s? if so, skip let is_zero_window = chip.gate().is_zero(ctx, bit_sum); - let add_point = ec_select_from_bits(chip, ctx, cached_point_window, bit_window); - curr_point = if let Some(curr_point) = curr_point { - let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, !scalar_is_safe); - let zero_sum = ec_select(chip, ctx, curr_point, sum, is_zero_window); - Some(ec_select(chip, ctx, zero_sum, add_point, is_started)) - } else { - Some(add_point) - }; - is_started = { - // is_started || !is_zero_window - // (a || !b) = (1-b) + a*b - let not_zero_window = chip.gate().not(ctx, is_zero_window); - chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window) + curr_point = { + let add_point = ec_select_from_bits(chip, ctx, cached_point_window, bit_window); + let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, true); + ec_select(chip, ctx, curr_point, sum, is_zero_window) }; } - curr_point.unwrap() + ec_sub_strict(chip, ctx, curr_point, any_point) } // basically just adding up individual fixed_base::scalar_multiply except that we do all batched normalization of cached points at once to further save inversion time during witness generation @@ -120,7 +107,7 @@ where /// * `scalars[i].len() = scalars[j].len()` for all `i,j` /// * `points` are all on the curve /// * `points[i]` is not point at infinity (0, 0); these should be filtered out beforehand -/// * The integer value of `scalars[i]` is less than the order of `points[i]` (some constraints may fail otherwise) +/// * The integer value of `scalars[i]` is less than the order of `points[i]` /// * Output may be point at infinity, in which case (0, 0) is returned pub fn msm_par( chip: &EccChip, @@ -153,6 +140,7 @@ where .flat_map(|point| -> Vec<_> { let base_pt = point.to_curve(); // cached_points[idx][i * 2^w + j] holds `[j * 2^(i * w)] * points[idx]` for j in {0, ..., 2^w - 1} + // EXCEPT cached_points[idx][0] = points[idx] let mut increment = base_pt; (0..num_windows) .flat_map(|i| { @@ -178,8 +166,9 @@ where C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine); let field_chip = chip.field_chip(); + let ctx = builder.main(phase); + let any_point = chip.load_random_point::(ctx); - let zero = builder.main(phase).load_zero(); let scalar_mults = parallelize_in( phase, builder, @@ -202,41 +191,29 @@ where }) .collect::>(); let bit_window_rev = bits.chunks(window_bits).rev(); - let mut curr_point = None; - // `is_started` is just a way to deal with if `curr_point` is actually identity - let mut is_started = zero; + let mut curr_point = any_point.clone(); for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) { let is_zero_window = { let sum = field_chip.gate().sum(ctx, bit_window.iter().copied()); field_chip.gate().is_zero(ctx, sum) }; - let add_point = - ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window); - curr_point = if let Some(curr_point) = curr_point { - // We don't need strict mode because we assume scalars[i] is less than the order of points[i] - let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, false); - let zero_sum = ec_select(field_chip, ctx, curr_point, sum, is_zero_window); - Some(ec_select(field_chip, ctx, zero_sum, add_point, is_started)) - } else { - Some(add_point) - }; - is_started = { - // is_started || !is_zero_window - // (a || !b) = (1-b) + a*b - let not_zero_window = field_chip.gate().not(ctx, is_zero_window); - field_chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window) + curr_point = { + let add_point = + ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window); + let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, true); + ec_select(field_chip, ctx, curr_point, sum, is_zero_window) }; } - (curr_point.unwrap(), is_started) + curr_point }, ); let ctx = builder.main(phase); // sum `scalar_mults` but take into account possiblity of identity points - let any_point = chip.load_random_point::(ctx); - let mut acc = any_point.clone(); - for (point, is_not_identity) in scalar_mults { + let any_point2 = chip.load_random_point::(ctx); + let mut acc = any_point2.clone(); + for point in scalar_mults { let new_acc = chip.add_unequal(ctx, &acc, point, true); - acc = chip.select(ctx, new_acc, acc, is_not_identity); + acc = chip.sub_unequal(ctx, new_acc, &any_point, true); } - ec_sub_strict(field_chip, ctx, acc, any_point) + ec_sub_strict(field_chip, ctx, acc, any_point2) } diff --git a/halo2-ecc/src/ecc/mod.rs b/halo2-ecc/src/ecc/mod.rs index 208eee11..87b383bd 100644 --- a/halo2-ecc/src/ecc/mod.rs +++ b/halo2-ecc/src/ecc/mod.rs @@ -3,9 +3,10 @@ use crate::fields::{fp::FpChip, FieldChip, PrimeField, Selectable}; use crate::halo2_proofs::arithmetic::CurveAffine; use group::{Curve, Group}; use halo2_base::gates::builder::GateThreadBuilder; +use halo2_base::utils::modulus; use halo2_base::{ gates::{GateInstructions, RangeInstructions}, - utils::{modulus, CurveAffineExt}, + utils::CurveAffineExt, AssignedValue, Context, }; use itertools::Itertools; @@ -480,26 +481,26 @@ where /// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` /// /// # Assumptions -/// - `P` is not the point at infinity -/// - `scalar > 0` -/// - If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`) -/// - Even if `scalar_is_safe == false`, some constraints may still fail if `scalar` is not in range [1, order of `P`) +/// - `window_bits != 0` +/// - The order of `P` is at least `2^{window_bits}` (in particular, `P` is not the point at infinity) +/// - The curve has no points of order 2. /// - `scalar_i < 2^{max_bits} for all i` /// - `max_bits <= modulus::.bits()`, and equality only allowed when the order of `P` equals the modulus of `F` -pub fn scalar_multiply( +pub fn scalar_multiply( chip: &FC, ctx: &mut Context, P: EcPoint, scalar: Vec>, max_bits: usize, window_bits: usize, - scalar_is_safe: bool, ) -> EcPoint where FC: FieldChip + Selectable, + C: CurveAffineExt, { assert!(!scalar.is_empty()); assert!((max_bits as u64) <= modulus::().bits()); + assert!(window_bits != 0); let total_bits = max_bits * scalar.len(); let num_windows = (total_bits + window_bits - 1) / window_bits; @@ -517,7 +518,7 @@ where // is_started[idx] holds whether there is a 1 in bits with index at least (rounded_bitlen - idx) let mut is_started = Vec::with_capacity(rounded_bitlen); is_started.resize(rounded_bitlen - total_bits + 1, zero_cell); - for idx in 1..total_bits { + for idx in 1..=total_bits { let or = chip.gate().or(ctx, *is_started.last().unwrap(), rounded_bits[total_bits - idx]); is_started.push(or); } @@ -534,22 +535,23 @@ where is_zero_window.push(is_zero); } - // cached_points[idx] stores idx * P, with cached_points[0] = P + let any_point = load_random_point::(chip, ctx); + // cached_points[idx] stores idx * P, with cached_points[0] = any_point let cache_size = 1usize << window_bits; let mut cached_points = Vec::with_capacity(cache_size); - cached_points.push(P.clone()); + cached_points.push(any_point); cached_points.push(P.clone()); for idx in 2..cache_size { if idx == 2 { let double = ec_double(chip, ctx, &P); cached_points.push(double); } else { - let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, !scalar_is_safe); + let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, false); cached_points.push(new_point); } } - // if all the starting window bits are 0, get start_point = P + // if all the starting window bits are 0, get start_point = any_point let mut curr_point = ec_select_from_bits( chip, ctx, @@ -569,13 +571,17 @@ where &rounded_bits [rounded_bitlen - window_bits * (idx + 1)..rounded_bitlen - window_bits * idx], ); - let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, !scalar_is_safe); + // if is_zero_window[idx] = true, add_point = any_point. We only need any_point to avoid divide by zero in add_unequal + // if is_zero_window = true and is_started = false, then mult_point = 2^window_bits * any_point. Since window_bits != 0, we have mult_point != +- any_point + let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, true); let is_started_point = ec_select(chip, ctx, mult_point, mult_and_add, is_zero_window[idx]); curr_point = ec_select(chip, ctx, is_started_point, add_point, is_started[window_bits * idx]); } - curr_point + // if at the end, return identity point (0,0) if still not started + let zero = chip.load_constant(ctx, FC::FieldType::zero()); + ec_select(chip, ctx, curr_point, EcPoint::new(zero.clone(), zero), *is_started.last().unwrap()) } /// Checks that `P` is indeed a point on the elliptic curve `C`. @@ -1018,24 +1024,18 @@ where } /// See [`scalar_multiply`] for more details. - pub fn scalar_mult( + pub fn scalar_mult( &self, ctx: &mut Context, P: EcPoint, scalar: Vec>, max_bits: usize, window_bits: usize, - scalar_is_safe: bool, - ) -> EcPoint { - scalar_multiply::( - self.field_chip, - ctx, - P, - scalar, - max_bits, - window_bits, - scalar_is_safe, - ) + ) -> EcPoint + where + C: CurveAffineExt, + { + scalar_multiply::(self.field_chip, ctx, P, scalar, max_bits, window_bits) } // default for most purposes @@ -1049,14 +1049,13 @@ where ) -> EcPoint where C: CurveAffineExt, - C::Base: ff::PrimeField, FC: Selectable, { // window_bits = 4 is optimal from empirical observations self.variable_base_msm_in::(thread_pool, P, scalars, max_bits, 4, 0) } - // TODO: put a check in place that scalar is < modulus of C::Scalar + // TODO: add asserts to validate input assumptions described in docs pub fn variable_base_msm_in( &self, builder: &mut GateThreadBuilder, @@ -1068,7 +1067,6 @@ where ) -> EcPoint where C: CurveAffineExt, - C::Base: ff::PrimeField, FC: Selectable, { #[cfg(feature = "display")] @@ -1115,7 +1113,6 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { scalar: Vec>, max_bits: usize, window_bits: usize, - scalar_is_safe: bool, ) -> EcPoint where C: CurveAffineExt, @@ -1128,7 +1125,6 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { scalar, max_bits, window_bits, - scalar_is_safe, ) } diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs index 27d4c1c6..45e251f3 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs @@ -1,5 +1,4 @@ #![allow(non_snake_case)] -use crate::fields::FpStrategy; use crate::halo2_proofs::{ arithmetic::CurveAffine, dev::MockProver, @@ -21,21 +20,10 @@ use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus}; use halo2_base::Context; use rand::random; use rand_core::OsRng; -use serde::{Deserialize, Serialize}; use std::fs::File; use test_case::test_case; -#[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct CircuitParams { - strategy: FpStrategy, - degree: u32, - num_advice: usize, - num_lookup_advice: usize, - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, -} +use super::CircuitParams; fn ecdsa_test( ctx: &mut Context, diff --git a/halo2-ecc/src/secp256k1/tests/mod.rs b/halo2-ecc/src/secp256k1/tests/mod.rs index cdd58dd8..803ac232 100644 --- a/halo2-ecc/src/secp256k1/tests/mod.rs +++ b/halo2-ecc/src/secp256k1/tests/mod.rs @@ -1,2 +1,162 @@ +#![allow(non_snake_case)] +use std::fs::File; + +use ff::Field; +use group::Curve; +use halo2_base::{ + gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, + }, + RangeChip, + }, + halo2_proofs::{ + dev::MockProver, + halo2curves::{ + bn256::Fr, + secp256k1::{Fq, Secp256k1Affine}, + }, + }, + utils::{biguint_to_fe, fe_to_biguint, BigPrimeField}, + Context, +}; +use num_bigint::BigUint; +use rand_core::OsRng; +use serde::{Deserialize, Serialize}; + +use crate::{ + ecc::EccChip, + fields::{FieldChip, FpStrategy}, + secp256k1::{FpChip, FqChip}, +}; + pub mod ecdsa; pub mod ecdsa_tests; + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +struct CircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, +} + +fn sm_test( + ctx: &mut Context, + params: CircuitParams, + base: Secp256k1Affine, + scalar: Fq, + window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::>::new(&fp_chip); + + let s = fq_chip.load_private(ctx, scalar); + let P = ecc_chip.assign_point(ctx, base); + + let sm = ecc_chip.scalar_mult::( + ctx, + P, + s.limbs().to_vec(), + fq_chip.limb_bits, + window_bits, + ); + + let sm_answer = (base * scalar).to_affine(); + + let sm_x = sm.x.value(); + let sm_y = sm.y.value(); + assert_eq!(sm_x, fe_to_biguint(&sm_answer.x)); + assert_eq!(sm_y, fe_to_biguint(&sm_answer.y)); +} + +fn sm_circuit( + params: CircuitParams, + stage: CircuitBuilderStage, + break_points: Option, + base: Secp256k1Affine, + scalar: Fq, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = GateThreadBuilder::new(stage == CircuitBuilderStage::Prover); + + sm_test(builder.main(0), params, base, scalar, 4); + + match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + } +} + +#[test] +fn test_secp_sm_random() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let circuit = sm_circuit( + params, + CircuitBuilderStage::Mock, + None, + Secp256k1Affine::random(OsRng), + Fq::random(OsRng), + ); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_secp_sm_minus_1() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let base = Secp256k1Affine::random(OsRng); + let mut s = -Fq::one(); + let mut n = fe_to_biguint(&s); + loop { + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + if &n % BigUint::from(2usize) == BigUint::from(0usize) { + break; + } + n /= 2usize; + s = biguint_to_fe(&n); + } +} + +#[test] +fn test_secp_sm_0_1() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let base = Secp256k1Affine::random(OsRng); + let s = Fq::zero(); + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + + let s = Fq::one(); + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} From 2c9233e01434138213aef29767b6adfb1595d737 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 6 Jun 2023 10:40:47 -0700 Subject: [PATCH 12/12] chore: trigger CI --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 564ddb6f..d6f2750d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,7 +4,7 @@ on: push: branches: ["main", "release-0.3.0"] pull_request: - branches: ["main"] + branches: ["main", "release-0.3.0"] env: CARGO_TERM_COLOR: always