Skip to content

Commit

Permalink
Reduce duplicate code across different curve cycle providers (lurk-la…
Browse files Browse the repository at this point in the history
…ng#255)

* refactor: impl folding macro

* refactor: generalize curve test

* chore: rename impl_folding to impl_engine
  • Loading branch information
ashWhiteHat authored Nov 20, 2023
1 parent 75cabf0 commit 3e05f5d
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 186 deletions.
36 changes: 1 addition & 35 deletions src/provider/bn256_grumpkin.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! This module implements the Nova traits for `bn256::Point`, `bn256::Scalar`, `grumpkin::Point`, `grumpkin::Scalar`.
use crate::{
impl_traits,
impl_engine, impl_traits,
provider::{
cpu_best_multiexp,
keccak::Keccak256Transcript,
Expand Down Expand Up @@ -69,37 +69,3 @@ impl_traits!(
"30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47",
"30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001"
);

#[cfg(test)]
mod tests {
use super::*;
type G = bn256::Point;

fn from_label_serial(label: &'static [u8], n: usize) -> Vec<Bn256Affine> {
let mut shake = Shake256::default();
shake.update(label);
let mut reader = shake.finalize_xof();
let mut ck = Vec::new();
for _ in 0..n {
let mut uniform_bytes = [0u8; 32];
reader.read_exact(&mut uniform_bytes).unwrap();
let hash = bn256::Point::hash_to_curve("from_uniform_bytes");
ck.push(hash(&uniform_bytes).to_affine());
}
ck
}

#[test]
fn test_from_label() {
let label = b"test_from_label";
for n in [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 1021,
] {
let ck_par = <G as DlogGroup>::from_label(label, n);
let ck_ser = from_label_serial(label, n);
assert_eq!(ck_par.len(), n);
assert_eq!(ck_ser.len(), n);
assert_eq!(ck_par, ck_ser);
}
}
}
136 changes: 103 additions & 33 deletions src/provider/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,29 +229,14 @@ macro_rules! impl_traits {
$order_str:literal,
$base_str:literal
) => {
impl Engine for $engine {
type Base = $name::Base;
type Scalar = $name::Scalar;
type GE = $name::Point;
type RO = PoseidonRO<Self::Base, Self::Scalar>;
type ROCircuit = PoseidonROCircuit<Self::Base>;
type TE = Keccak256Transcript<Self>;
type CE = CommitmentEngine<Self>;
}

impl Group for $name::Point {
type Base = $name::Base;
type Scalar = $name::Scalar;

fn group_params() -> (Self::Base, Self::Base, BigInt, BigInt) {
let A = $name::Point::a();
let B = $name::Point::b();
let order = BigInt::from_str_radix($order_str, 16).unwrap();
let base = BigInt::from_str_radix($base_str, 16).unwrap();

(A, B, order, base)
}
}
impl_engine!(
$engine,
$name,
$name_compressed,
$name_curve,
$order_str,
$base_str
);

impl DlogGroup for $name::Point {
type CompressedGroupElement = $name_compressed;
Expand Down Expand Up @@ -335,10 +320,11 @@ macro_rules! impl_traits {
}
}

impl PrimeFieldExt for $name::Scalar {
fn from_uniform(bytes: &[u8]) -> Self {
let bytes_arr: [u8; 64] = bytes.try_into().unwrap();
$name::Scalar::from_uniform_bytes(&bytes_arr)
impl CompressedGroup for $name_compressed {
type GroupElement = $name::Point;

fn decompress(&self) -> Option<$name::Point> {
Some($name_curve::from_bytes(&self).unwrap())
}
}

Expand All @@ -347,12 +333,48 @@ macro_rules! impl_traits {
self.as_ref().to_vec()
}
}
};
}

impl CompressedGroup for $name_compressed {
type GroupElement = $name::Point;
/// Nova folding circuit engine and curve group ops
#[macro_export]
macro_rules! impl_engine {
(
$engine:ident,
$name:ident,
$name_compressed:ident,
$name_curve:ident,
$order_str:literal,
$base_str:literal
) => {
impl Engine for $engine {
type Base = $name::Base;
type Scalar = $name::Scalar;
type GE = $name::Point;
type RO = PoseidonRO<Self::Base, Self::Scalar>;
type ROCircuit = PoseidonROCircuit<Self::Base>;
type TE = Keccak256Transcript<Self>;
type CE = CommitmentEngine<Self>;
}

fn decompress(&self) -> Option<$name::Point> {
Some($name_curve::from_bytes(&self).unwrap())
impl Group for $name::Point {
type Base = $name::Base;
type Scalar = $name::Scalar;

fn group_params() -> (Self::Base, Self::Base, BigInt, BigInt) {
let A = $name::Point::a();
let B = $name::Point::b();
let order = BigInt::from_str_radix($order_str, 16).unwrap();
let base = BigInt::from_str_radix($base_str, 16).unwrap();

(A, B, order, base)
}
}

impl PrimeFieldExt for $name::Scalar {
fn from_uniform(bytes: &[u8]) -> Self {
let bytes_arr: [u8; 64] = bytes.try_into().unwrap();
$name::Scalar::from_uniform_bytes(&bytes_arr)
}
}

Expand All @@ -371,11 +393,44 @@ mod tests {
use crate::provider::{
bn256_grumpkin::{bn256, grumpkin},
secp_secq::{secp256k1, secq256k1},
DlogGroup,
};
use group::{ff::Field, Group};
use halo2curves::CurveAffine;
use digest::{ExtendableOutput, Update};
use group::{ff::Field, Curve, Group};
use halo2curves::{CurveAffine, CurveExt};
use pasta_curves::{pallas, vesta};
use rand_core::OsRng;
use sha3::Shake256;
use std::io::Read;

macro_rules! impl_cycle_pair_test {
($curve:ident) => {
fn from_label_serial(label: &'static [u8], n: usize) -> Vec<$curve::Affine> {
let mut shake = Shake256::default();
shake.update(label);
let mut reader = shake.finalize_xof();
(0..n)
.map(|_| {
let mut uniform_bytes = [0u8; 32];
reader.read_exact(&mut uniform_bytes).unwrap();
let hash = $curve::Point::hash_to_curve("from_uniform_bytes");
hash(&uniform_bytes).to_affine()
})
.collect()
}

let label = b"test_from_label";
for n in [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 1021,
] {
let ck_par = <$curve::Point as DlogGroup>::from_label(label, n);
let ck_ser = from_label_serial(label, n);
assert_eq!(ck_par.len(), n);
assert_eq!(ck_ser.len(), n);
assert_eq!(ck_par, ck_ser);
}
};
}

fn test_msm_with<F: Field, A: CurveAffine<ScalarExt = F>>() {
let n = 8;
Expand Down Expand Up @@ -403,4 +458,19 @@ mod tests {
test_msm_with::<secp256k1::Scalar, secp256k1::Affine>();
test_msm_with::<secq256k1::Scalar, secq256k1::Affine>();
}

#[test]
fn test_bn256_from_label() {
impl_cycle_pair_test!(bn256);
}

#[test]
fn test_pallas_from_label() {
impl_cycle_pair_test!(pallas);
}

#[test]
fn test_secp256k1_from_label() {
impl_cycle_pair_test!(secp256k1);
}
}
99 changes: 16 additions & 83 deletions src/provider/pasta.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! This module implements the Nova traits for `pallas::Point`, `pallas::Scalar`, `vesta::Point`, `vesta::Scalar`.
use crate::{
impl_engine,
provider::{
cpu_best_multiexp,
keccak::Keccak256Transcript,
Expand Down Expand Up @@ -68,62 +69,41 @@ macro_rules! impl_traits {
$order_str:literal,
$base_str:literal
) => {
impl Engine for $engine {
type Base = $name::Base;
type Scalar = $name::Scalar;
type GE = $name::Point;
type RO = PoseidonRO<Self::Base, Self::Scalar>;
type ROCircuit = PoseidonROCircuit<Self::Base>;
type TE = Keccak256Transcript<Self>;
type CE = CommitmentEngine<Self>;
}

impl Group for $name::Point {
type Base = $name::Base;
type Scalar = $name::Scalar;

fn group_params() -> (Self::Base, Self::Base, BigInt, BigInt) {
let A = $name::Point::a();
let B = $name::Point::b();
let order = BigInt::from_str_radix($order_str, 16).unwrap();
let base = BigInt::from_str_radix($base_str, 16).unwrap();

(A, B, order, base)
}
}
impl_engine!(
$engine,
$name,
$name_compressed,
$name_curve,
$order_str,
$base_str
);

impl DlogGroup for $name::Point {
type CompressedGroupElement = $name_compressed;
type PreprocessedGroupElement = $name::Affine;

#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
fn vartime_multiscalar_mul(
scalars: &[Self::Scalar],
bases: &[Self::PreprocessedGroupElement],
) -> Self {
#[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
if scalars.len() >= 128 {
pasta_msm::$name(bases, scalars)
} else {
cpu_best_multiexp(scalars, bases)
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
cpu_best_multiexp(scalars, bases)
}

#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
fn vartime_multiscalar_mul(
scalars: &[Self::Scalar],
bases: &[Self::PreprocessedGroupElement],
) -> Self {
cpu_best_multiexp(scalars, bases)
fn preprocessed(&self) -> Self::PreprocessedGroupElement {
self.to_affine()
}

fn compress(&self) -> Self::CompressedGroupElement {
$name_compressed::new(self.to_bytes())
}

fn preprocessed(&self) -> Self::PreprocessedGroupElement {
self.to_affine()
}

fn from_label(label: &'static [u8], n: usize) -> Vec<Self::PreprocessedGroupElement> {
let mut shake = Shake256::default();
shake.update(label);
Expand Down Expand Up @@ -184,19 +164,6 @@ macro_rules! impl_traits {
}
}

impl PrimeFieldExt for $name::Scalar {
fn from_uniform(bytes: &[u8]) -> Self {
let bytes_arr: [u8; 64] = bytes.try_into().unwrap();
$name::Scalar::from_uniform_bytes(&bytes_arr)
}
}

impl<G: DlogGroup> TranscriptReprTrait<G> for $name_compressed {
fn to_transcript_bytes(&self) -> Vec<u8> {
self.repr.to_vec()
}
}

impl CompressedGroup for $name_compressed {
type GroupElement = $name::Point;

Expand All @@ -205,9 +172,9 @@ macro_rules! impl_traits {
}
}

impl<G: Group> TranscriptReprTrait<G> for $name::Scalar {
impl<G: DlogGroup> TranscriptReprTrait<G> for $name_compressed {
fn to_transcript_bytes(&self) -> Vec<u8> {
self.to_repr().to_vec()
self.repr.to_vec()
}
}
};
Expand All @@ -232,37 +199,3 @@ impl_traits!(
"40000000000000000000000000000000224698fc094cf91b992d30ed00000001",
"40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001"
);

#[cfg(test)]
mod tests {
use super::*;
type G = <PallasEngine as Engine>::GE;

fn from_label_serial(label: &'static [u8], n: usize) -> Vec<EpAffine> {
let mut shake = Shake256::default();
shake.update(label);
let mut reader = shake.finalize_xof();
let mut ck = Vec::new();
for _ in 0..n {
let mut uniform_bytes = [0u8; 32];
reader.read_exact(&mut uniform_bytes).unwrap();
let hash = Ep::hash_to_curve("from_uniform_bytes");
ck.push(hash(&uniform_bytes).to_affine());
}
ck
}

#[test]
fn test_from_label() {
let label = b"test_from_label";
for n in [
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 1021,
] {
let ck_par = <G as DlogGroup>::from_label(label, n);
let ck_ser = from_label_serial(label, n);
assert_eq!(ck_par.len(), n);
assert_eq!(ck_ser.len(), n);
assert_eq!(ck_par, ck_ser);
}
}
}
Loading

0 comments on commit 3e05f5d

Please sign in to comment.