diff --git a/Cargo.lock b/Cargo.lock index ff2f89b002..cdb535e480 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1314,7 +1314,10 @@ dependencies = [ "ark-serialize", "ark-std", "ark-test-curves", + "derivative", + "num-bigint", "rand", + "wasm-bindgen", ] [[package]] diff --git a/curves/Cargo.toml b/curves/Cargo.toml index 2116b66c86..f97a60cbae 100644 --- a/curves/Cargo.toml +++ b/curves/Cargo.toml @@ -12,10 +12,13 @@ license = "Apache-2.0" [dependencies] ark-ec.workspace = true ark-ff.workspace = true +ark-serialize.workspace = true +wasm-bindgen.workspace = true +num-bigint.workspace = true +derivative = { version = "2.0", features = ["use_core"] } [dev-dependencies] -rand.workspace = true +rand.workspace = true ark-test-curves.workspace = true ark-algebra-test-templates.workspace = true -ark-serialize.workspace = true ark-std.workspace = true diff --git a/curves/src/pasta/mod.rs b/curves/src/pasta/mod.rs index abd4207fa3..4d3eb86cfa 100644 --- a/curves/src/pasta/mod.rs +++ b/curves/src/pasta/mod.rs @@ -1,5 +1,6 @@ pub mod curves; pub mod fields; +pub mod wasm_friendly; pub use curves::{ pallas::{Pallas, PallasParameters, ProjectivePallas}, diff --git a/curves/src/pasta/wasm_friendly/backend9.rs b/curves/src/pasta/wasm_friendly/backend9.rs new file mode 100644 index 0000000000..1d6808175f --- /dev/null +++ b/curves/src/pasta/wasm_friendly/backend9.rs @@ -0,0 +1,213 @@ +/** + * Implementation of `FpBackend` for N=9, using 29-bit limbs represented by `u32`s. + */ +use super::bigint32::BigInt; +use super::wasm_fp::{Fp, FpBackend}; + +type B = [u32; 9]; +type B64 = [u64; 9]; + +const SHIFT: u32 = 29; +const MASK: u32 = (1 << SHIFT) - 1; + +const SHIFT64: u64 = SHIFT as u64; +const MASK64: u64 = MASK as u64; + +pub const fn from_64x4(pa: [u64; 4]) -> [u32; 9] { + let mut p = [0u32; 9]; + p[0] = (pa[0] & MASK64) as u32; + p[1] = ((pa[0] >> 29) & MASK64) as u32; + p[2] = (((pa[0] >> 58) | (pa[1] << 6)) & MASK64) as u32; + p[3] = ((pa[1] >> 23) & MASK64) as u32; + p[4] = (((pa[1] >> 52) | (pa[2] << 12)) & MASK64) as u32; + p[5] = ((pa[2] >> 17) & MASK64) as u32; + p[6] = (((pa[2] >> 46) | (pa[3] << 18)) & MASK64) as u32; + p[7] = ((pa[3] >> 11) & MASK64) as u32; + p[8] = (pa[3] >> 40) as u32; + p +} +pub const fn to_64x4(pa: [u32; 9]) -> [u64; 4] { + let mut p = [0u64; 4]; + p[0] = pa[0] as u64; + p[0] |= (pa[1] as u64) << 29; + p[0] |= (pa[2] as u64) << 58; + p[1] = (pa[2] as u64) >> 6; + p[1] |= (pa[3] as u64) << 23; + p[1] |= (pa[4] as u64) << 52; + p[2] = (pa[4] as u64) >> 12; + p[2] |= (pa[5] as u64) << 17; + p[2] |= (pa[6] as u64) << 46; + p[3] = (pa[6] as u64) >> 18; + p[3] |= (pa[7] as u64) << 11; + p[3] |= (pa[8] as u64) << 40; + p +} + +pub trait FpConstants: Send + Sync + 'static + Sized { + const MODULUS: B; + const MODULUS64: B64 = { + let mut modulus64 = [0u64; 9]; + let modulus = Self::MODULUS; + let mut i = 0; + while i < 9 { + modulus64[i] = modulus[i] as u64; + i += 1; + } + modulus64 + }; + + /// montgomery params + /// TODO: compute these + const R: B; // R = 2^261 mod modulus + const R2: B; // R^2 mod modulus + const MINV: u64; // -modulus^(-1) mod 2^29, as a u64 +} + +#[inline] +fn gte_modulus(x: &B) -> bool { + for i in (0..9).rev() { + // don't fix warning -- that makes it 15% slower! + #[allow(clippy::comparison_chain)] + if x[i] > FpC::MODULUS[i] { + return true; + } else if x[i] < FpC::MODULUS[i] { + return false; + } + } + true +} + +// TODO performance ideas to test: +// - unroll loops +// - introduce locals for a[i] instead of accessing memory multiple times +// - only do 1 carry pass at the end, by proving properties of greater-than on uncarried result +// - use cheaper, approximate greater-than check a[8] > Fp::MODULUS[8] +pub fn add_assign(x: &mut B, y: &B) { + let mut tmp: u32; + let mut carry: i32 = 0; + + for i in 0..9 { + tmp = x[i] + y[i] + (carry as u32); + carry = (tmp as i32) >> SHIFT; + x[i] = tmp & MASK; + } + + if gte_modulus::(x) { + carry = 0; + #[allow(clippy::needless_range_loop)] + for i in 0..9 { + tmp = x[i].wrapping_sub(FpC::MODULUS[i]) + (carry as u32); + carry = (tmp as i32) >> SHIFT; + x[i] = tmp & MASK; + } + } +} + +#[inline] +fn conditional_reduce(x: &mut B) { + if gte_modulus::(x) { + #[allow(clippy::needless_range_loop)] + for i in 0..9 { + x[i] = x[i].wrapping_sub(FpC::MODULUS[i]); + } + #[allow(clippy::needless_range_loop)] + for i in 1..9 { + x[i] += ((x[i - 1] as i32) >> SHIFT) as u32; + } + #[allow(clippy::needless_range_loop)] + for i in 0..8 { + x[i] &= MASK; + } + } +} + +/// Montgomery multiplication +pub fn mul_assign(x: &mut B, y: &B) { + // load y[i] into local u64s + // TODO make sure these are locals + let mut y_local = [0u64; 9]; + for i in 0..9 { + y_local[i] = y[i] as u64; + } + + // locals for result + let mut z = [0u64; 8]; + let mut tmp: u64; + + // main loop, without intermediate carries except for z0 + #[allow(clippy::needless_range_loop)] + for i in 0..9 { + let xi = x[i] as u64; + + // compute qi and carry z0 result to z1 before discarding z0 + tmp = (xi * y_local[0]) + z[0]; + let qi = ((tmp & MASK64) * FpC::MINV) & MASK64; + z[1] += (tmp + qi * FpC::MODULUS64[0]) >> SHIFT64; + + // compute zi and shift in one step + for j in 1..8 { + z[j - 1] = z[j] + (xi * y_local[j]) + (qi * FpC::MODULUS64[j]); + } + // for j=8 we save an addition since z[8] is never needed + z[7] = xi * y_local[8] + qi * FpC::MODULUS64[8]; + } + + // final carry pass, store result back into x + x[0] = (z[0] & MASK64) as u32; + for i in 1..8 { + x[i] = (((z[i - 1] >> SHIFT64) + z[i]) & MASK64) as u32; + } + x[8] = (z[7] >> SHIFT64) as u32; + + // at this point, x is guaranteed to be less than 2*MODULUS + // conditionally subtract the modulus to bring it back into the canonical range + conditional_reduce::(x); +} + +// implement FpBackend given FpConstants + +pub fn from_bigint_unsafe(x: BigInt<9>) -> Fp { + let mut r = x.0; + // convert to montgomery form + mul_assign::(&mut r, &FpC::R2); + Fp(BigInt(r), Default::default()) +} + +impl FpBackend<9> for FpC { + const MODULUS: BigInt<9> = BigInt(Self::MODULUS); + const ZERO: BigInt<9> = BigInt([0; 9]); + const ONE: BigInt<9> = BigInt(Self::R); + + fn add_assign(x: &mut Fp, y: &Fp) { + add_assign::(&mut x.0 .0, &y.0 .0); + } + + fn mul_assign(x: &mut Fp, y: &Fp) { + mul_assign::(&mut x.0 .0, &y.0 .0); + } + + fn from_bigint(x: BigInt<9>) -> Option> { + if gte_modulus::(&x.0) { + None + } else { + Some(from_bigint_unsafe(x)) + } + } + fn to_bigint(x: Fp) -> BigInt<9> { + let one = [1, 0, 0, 0, 0, 0, 0, 0, 0]; + let mut r = x.0 .0; + // convert back from montgomery form + mul_assign::(&mut r, &one); + BigInt(r) + } + + fn pack(x: Fp) -> Vec { + let x = Self::to_bigint(x).0; + let x64 = to_64x4(x); + let mut res = Vec::with_capacity(4); + for limb in x64.iter() { + res.push(*limb); + } + res + } +} diff --git a/curves/src/pasta/wasm_friendly/bigint32.rs b/curves/src/pasta/wasm_friendly/bigint32.rs new file mode 100644 index 0000000000..f7e2357795 --- /dev/null +++ b/curves/src/pasta/wasm_friendly/bigint32.rs @@ -0,0 +1,52 @@ +/** + * BigInt with 32-bit limbs + * + * Contains everything for wasm_fp which is unrelated to being a field + * + * Code is mostly copied from ark-ff::BigInt + */ +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate, + Write, +}; + +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +pub struct BigInt(pub [u32; N]); + +impl Default for BigInt { + fn default() -> Self { + Self([0u32; N]) + } +} + +impl CanonicalSerialize for BigInt { + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.0.serialize_with_mode(writer, compress) + } + + fn serialized_size(&self, compress: Compress) -> usize { + self.0.serialized_size(compress) + } +} + +impl Valid for BigInt { + fn check(&self) -> Result<(), SerializationError> { + self.0.check() + } +} + +impl CanonicalDeserialize for BigInt { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + Ok(BigInt::(<[u32; N]>::deserialize_with_mode( + reader, compress, validate, + )?)) + } +} diff --git a/curves/src/pasta/wasm_friendly/minimal_field.rs b/curves/src/pasta/wasm_friendly/minimal_field.rs new file mode 100644 index 0000000000..367b82b740 --- /dev/null +++ b/curves/src/pasta/wasm_friendly/minimal_field.rs @@ -0,0 +1,44 @@ +use ark_ff::{BitIteratorBE, One, Zero}; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use std::ops::{Add, AddAssign, Mul, MulAssign}; + +/** + * Minimal Field trait needed to implement Poseidon + */ +pub trait MinimalField: + 'static + + Copy + + Clone + + CanonicalSerialize + + CanonicalDeserialize + + Zero + + One + + for<'a> Add<&'a Self, Output = Self> + + for<'a> Mul<&'a Self, Output = Self> + + for<'a> AddAssign<&'a Self> + + for<'a> MulAssign<&'a Self> +{ + /// Squares `self` in place. + fn square_in_place(&mut self) -> &mut Self; + + /// Returns `self^exp`, where `exp` is an integer represented with `u64` limbs, + /// least significant limb first. + fn pow>(&self, exp: S) -> Self { + let mut res = Self::one(); + + for i in BitIteratorBE::without_leading_zeros(exp) { + res.square_in_place(); + + if i { + res *= self; + } + } + res + } +} + +impl MinimalField for F { + fn square_in_place(&mut self) -> &mut Self { + self.square_in_place() + } +} diff --git a/curves/src/pasta/wasm_friendly/mod.rs b/curves/src/pasta/wasm_friendly/mod.rs new file mode 100644 index 0000000000..15800465ec --- /dev/null +++ b/curves/src/pasta/wasm_friendly/mod.rs @@ -0,0 +1,12 @@ +pub mod bigint32; +pub use bigint32::BigInt; + +pub mod minimal_field; +pub use minimal_field::MinimalField; + +pub mod wasm_fp; +pub use wasm_fp::Fp; + +pub mod backend9; +pub mod pasta; +pub use pasta::Fp9; diff --git a/curves/src/pasta/wasm_friendly/pasta.rs b/curves/src/pasta/wasm_friendly/pasta.rs new file mode 100644 index 0000000000..f6cdb802b8 --- /dev/null +++ b/curves/src/pasta/wasm_friendly/pasta.rs @@ -0,0 +1,33 @@ +use super::{backend9, wasm_fp}; +use crate::pasta::Fp; +use ark_ff::PrimeField; + +pub struct Fp9Parameters; + +impl backend9::FpConstants for Fp9Parameters { + const MODULUS: [u32; 9] = [ + 0x1, 0x9698768, 0x133e46e6, 0xd31f812, 0x224, 0x0, 0x0, 0x0, 0x400000, + ]; + const R: [u32; 9] = [ + 0x1fffff81, 0x14a5d367, 0x141ad3c0, 0x1435eec5, 0x1ffeefef, 0x1fffffff, 0x1fffffff, + 0x1fffffff, 0x3fffff, + ]; + const R2: [u32; 9] = [ + 0x3b6a, 0x19c10910, 0x1a6a0188, 0x12a4fd88, 0x634b36d, 0x178792ba, 0x7797a99, 0x1dce5b8a, + 0x3506bd, + ]; + const MINV: u64 = 0x1fffffff; +} +pub type Fp9 = wasm_fp::Fp; + +impl Fp9 { + pub fn from_fp(fp: Fp) -> Self { + backend9::from_bigint_unsafe(super::BigInt(backend9::from_64x4(fp.into_bigint().0))) + } +} + +impl From for Fp9 { + fn from(fp: Fp) -> Self { + Fp9::from_fp(fp) + } +} diff --git a/curves/src/pasta/wasm_friendly/wasm_fp.rs b/curves/src/pasta/wasm_friendly/wasm_fp.rs new file mode 100644 index 0000000000..7c0d45d34a --- /dev/null +++ b/curves/src/pasta/wasm_friendly/wasm_fp.rs @@ -0,0 +1,236 @@ +/** + * MinimalField trait implementation `Fp` which only depends on an `FpBackend` trait + * + * Most of this code was copied over from ark_ff::Fp + */ +use crate::pasta::wasm_friendly::bigint32::BigInt; +use ark_ff::{One, Zero}; +use ark_serialize::{ + CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid, Validate, + Write, +}; +use derivative::Derivative; +use num_bigint::BigUint; +use std::{ + marker::PhantomData, + ops::{Add, AddAssign, Mul, MulAssign}, +}; + +use super::minimal_field::MinimalField; + +pub trait FpBackend: Send + Sync + 'static + Sized { + const MODULUS: BigInt; + const ZERO: BigInt; + const ONE: BigInt; + + fn add_assign(a: &mut Fp, b: &Fp); + fn mul_assign(a: &mut Fp, b: &Fp); + + /// Construct a field element from an integer in the range + /// `0..(Self::MODULUS - 1)`. Returns `None` if the integer is outside + /// this range. + fn from_bigint(x: BigInt) -> Option>; + fn to_bigint(x: Fp) -> BigInt; + + fn pack(x: Fp) -> Vec; +} + +/// Represents an element of the prime field F_p, where `p == P::MODULUS`. +/// This type can represent elements in any field of size at most N * 64 bits. +#[derive(Derivative)] +#[derivative( + Default(bound = ""), + Hash(bound = ""), + Copy(bound = ""), + PartialEq(bound = ""), + Eq(bound = ""), + Debug(bound = "") +)] +pub struct Fp, const N: usize>( + pub BigInt, + #[derivative(Debug = "ignore")] + #[doc(hidden)] + pub PhantomData

, +); + +impl, const N: usize> Clone for Fp { + fn clone(&self) -> Self { + *self + } +} + +impl, const N: usize> Fp { + pub fn new(bigint: BigInt) -> Self { + Fp(bigint, Default::default()) + } + + #[inline] + pub fn from_bigint(r: BigInt) -> Option { + P::from_bigint(r) + } + #[inline] + pub fn into_bigint(self) -> BigInt { + P::to_bigint(self) + } + + pub fn to_bytes_le(self) -> Vec { + let chunks = P::pack(self).into_iter().map(|x| x.to_le_bytes()); + let mut bytes = Vec::with_capacity(chunks.len() * 8); + for chunk in chunks { + bytes.extend_from_slice(&chunk); + } + bytes + } +} + +// coerce into Fp from either BigInt or [u32; N] + +impl, const N: usize> From> for Fp { + fn from(val: BigInt) -> Self { + Fp::from_bigint(val).unwrap() + } +} + +impl, const N: usize> From<[u32; N]> for Fp { + fn from(val: [u32; N]) -> Self { + Fp::from_bigint(BigInt(val)).unwrap() + } +} + +// field + +impl, const N: usize> MinimalField for Fp { + fn square_in_place(&mut self) -> &mut Self { + // implemented with mul_assign for now + let self_copy = *self; + self.mul_assign(&self_copy); + self + } +} + +// add, zero + +impl, const N: usize> Zero for Fp { + #[inline] + fn zero() -> Self { + Fp::new(P::ZERO) + } + + #[inline] + fn is_zero(&self) -> bool { + *self == Self::zero() + } +} + +impl<'a, P: FpBackend, const N: usize> AddAssign<&'a Self> for Fp { + #[inline] + fn add_assign(&mut self, other: &Self) { + P::add_assign(self, other) + } +} +impl, const N: usize> Add for Fp { + type Output = Self; + + #[inline] + fn add(mut self, other: Self) -> Self { + self.add_assign(&other); + self + } +} +impl<'a, P: FpBackend, const N: usize> Add<&'a Fp> for Fp { + type Output = Self; + + #[inline] + fn add(mut self, other: &Self) -> Self { + self.add_assign(other); + self + } +} + +// mul, one + +impl, const N: usize> One for Fp { + #[inline] + fn one() -> Self { + Fp::new(P::ONE) + } + + #[inline] + fn is_one(&self) -> bool { + *self == Self::one() + } +} +impl<'a, P: FpBackend, const N: usize> MulAssign<&'a Self> for Fp { + #[inline] + fn mul_assign(&mut self, other: &Self) { + P::mul_assign(self, other) + } +} +impl, const N: usize> Mul for Fp { + type Output = Self; + + #[inline] + fn mul(mut self, other: Self) -> Self { + self.mul_assign(&other); + self + } +} +impl<'a, P: FpBackend, const N: usize> Mul<&'a Fp> for Fp { + type Output = Self; + + #[inline] + fn mul(mut self, other: &Self) -> Self { + self.mul_assign(other); + self + } +} + +// (de)serialization + +impl, const N: usize> CanonicalSerialize for Fp { + #[inline] + fn serialize_with_mode( + &self, + writer: W, + compress: Compress, + ) -> Result<(), SerializationError> { + self.0.serialize_with_mode(writer, compress) + } + + #[inline] + fn serialized_size(&self, compress: Compress) -> usize { + self.0.serialized_size(compress) + } +} + +impl, const N: usize> Valid for Fp { + fn check(&self) -> Result<(), SerializationError> { + Ok(()) + } +} + +impl, const N: usize> CanonicalDeserialize for Fp { + fn deserialize_with_mode( + reader: R, + compress: Compress, + validate: Validate, + ) -> Result { + Self::from_bigint(BigInt::deserialize_with_mode(reader, compress, validate)?) + .ok_or(SerializationError::InvalidData) + } +} + +// display + +impl, const N: usize> From> for BigUint { + #[inline] + fn from(val: Fp) -> BigUint { + BigUint::from_bytes_le(&val.to_bytes_le()) + } +} + +impl, const N: usize> std::fmt::Display for Fp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + BigUint::from(*self).fmt(f) + } +} diff --git a/kimchi/benches/proof_criterion.rs b/kimchi/benches/proof_criterion.rs index 8a1923e48f..b73987919c 100644 --- a/kimchi/benches/proof_criterion.rs +++ b/kimchi/benches/proof_criterion.rs @@ -3,17 +3,18 @@ use kimchi::bench::BenchmarkCtx; pub fn bench_proof_creation(c: &mut Criterion) { let mut group = c.benchmark_group("Proof creation"); + //group.significance_level(0.1).sample_size(50); group.sample_size(10).sampling_mode(SamplingMode::Flat); // for slow benchmarks let ctx = BenchmarkCtx::new(10); group.bench_function( - format!("proof creation (SRS size 2^{})", ctx.srs_size()), + format!("proof creation small (SRS size 2^{})", ctx.srs_size()), |b| b.iter(|| black_box(ctx.create_proof())), ); let ctx = BenchmarkCtx::new(14); group.bench_function( - format!("proof creation (SRS size 2^{})", ctx.srs_size()), + format!("proof creation big (SRS size 2^{})", ctx.srs_size()), |b| b.iter(|| black_box(ctx.create_proof())), ); @@ -21,7 +22,7 @@ pub fn bench_proof_creation(c: &mut Criterion) { group.sample_size(100).sampling_mode(SamplingMode::Auto); group.bench_function( - format!("proof verification (SRS size 2^{})", ctx.srs_size()), + format!("proof verification big (SRS size 2^{})", ctx.srs_size()), |b| b.iter(|| ctx.batch_verification(black_box(&vec![proof_and_public.clone()]))), ); } diff --git a/poseidon/benches/poseidon_bench.rs b/poseidon/benches/poseidon_bench.rs index 81a14d20fb..74c5040095 100644 --- a/poseidon/benches/poseidon_bench.rs +++ b/poseidon/benches/poseidon_bench.rs @@ -1,10 +1,12 @@ +use ark_ff::Zero; use criterion::{criterion_group, criterion_main, Criterion}; -use mina_curves::pasta::Fp; +use mina_curves::pasta::{wasm_friendly::Fp9, Fp}; use mina_poseidon::{ constants::PlonkSpongeConstantsKimchi, pasta::fp_kimchi as SpongeParametersKimchi, - poseidon::{ArithmeticSponge as Poseidon, Sponge}, + poseidon::{ArithmeticSponge as Poseidon, ArithmeticSpongeParams, Sponge}, }; +use once_cell::sync::Lazy; pub fn bench_poseidon_kimchi(c: &mut Criterion) { let mut group = c.benchmark_group("Poseidon"); @@ -17,6 +19,23 @@ pub fn bench_poseidon_kimchi(c: &mut Criterion) { SpongeParametersKimchi::static_params(), ); + // poseidon.absorb(&[Fp::zero()]); + // println!("{}", poseidon.squeeze()); + + b.iter(|| { + poseidon.absorb(&[hash]); + hash = poseidon.squeeze(); + }) + }); + + // same as above but with Fp9 + group.bench_function("poseidon_hash_kimchi_fp9", |b| { + let mut hash: Fp9 = Fp9::zero(); + let mut poseidon = Poseidon::::new(fp9_static_params()); + + // poseidon.absorb(&[Fp9::zero()]); + // println!("{}", poseidon.squeeze()); + b.iter(|| { poseidon.absorb(&[hash]); hash = poseidon.squeeze(); @@ -26,5 +45,192 @@ pub fn bench_poseidon_kimchi(c: &mut Criterion) { group.finish(); } -criterion_group!(benches, bench_poseidon_kimchi); +pub fn bench_conversions(c: &mut Criterion) { + let mut group = c.benchmark_group("Conversions"); + + group.bench_function("Conversion: fp_to_fp9", |b| { + b.iter_batched( + || { + let x: Fp = rand::random(); + x + }, + |x| { + let z: Fp9 = x.into(); + z + }, + criterion::BatchSize::SmallInput, + ); + }); + + group.bench_function("Conversion: fp_to_fp9, 2^16 elements", |b| { + b.iter_batched( + || (0..65536).map(|_| rand::random()).collect(), + |hashes_fp: Vec| { + let mut hashes_fp9: Vec = Vec::with_capacity(65536); + for h in hashes_fp.clone().into_iter() { + hashes_fp9.push(h.into()); + } + hashes_fp9 + }, + criterion::BatchSize::SmallInput, + ); + }); +} + +pub fn bench_basic_ops(c: &mut Criterion) { + let mut group = c.benchmark_group("Basic ops"); + + group.bench_function("Native multiplication in Fp (single)", |b| { + b.iter_batched( + || { + let x: Fp = rand::random(); + let y: Fp = rand::random(); + (x, y) + }, + |(x, y)| { + let z: Fp = x * y; + z + }, + criterion::BatchSize::SmallInput, + ); + }); + + group.bench_function("Multiplication in Fp9 (single)", |b| { + b.iter_batched( + || { + let x: Fp = rand::random(); + let y: Fp = rand::random(); + let x_fp9: Fp9 = x.into(); + let y_fp9: Fp9 = y.into(); + (x_fp9, y_fp9) + }, + |(x_fp9, y_fp9)| { + let z_fp9: Fp9 = x_fp9 * y_fp9; + z_fp9 + }, + criterion::BatchSize::SmallInput, + ); + }); + + group.bench_function("Multiplication in Fp9 with a conversion (single)", |b| { + b.iter_batched( + || { + let x: Fp = rand::random(); + let y: Fp = rand::random(); + (x, y) + }, + |(x, y)| { + let x_fp9: Fp9 = From::from(x); + let y_fp9: Fp9 = From::from(y); + let z_fp9: Fp9 = x_fp9 * y_fp9; + z_fp9 + }, + criterion::BatchSize::SmallInput, + ); + }); + + group.bench_function("Native multiplication in Fp (double)", |b| { + b.iter_batched( + || { + let x: Fp = rand::random(); + let y: Fp = rand::random(); + (x, y) + }, + |(x, y)| { + let z: Fp = x * y; + let z: Fp = z * x; + z + }, + criterion::BatchSize::SmallInput, + ); + }); + + group.bench_function("Multiplication in Fp9 with a conversion (double)", |b| { + b.iter_batched( + || { + let x: Fp = rand::random(); + let y: Fp = rand::random(); + (x, y) + }, + |(x, y)| { + let x_fp9: Fp9 = From::from(x); + let y_fp9: Fp9 = From::from(y); + let z_fp9: Fp9 = x_fp9 * y_fp9; + let z_fp9: Fp9 = z_fp9 * x_fp9; + z_fp9 + }, + criterion::BatchSize::SmallInput, + ); + }); + + group.bench_function("Native multiplication in Fp (4 muls)", |b| { + b.iter_batched( + || { + let x: Fp = rand::random(); + let y: Fp = rand::random(); + (x, y) + }, + |(x, y)| { + let z: Fp = x * y; + let z: Fp = z * x; + let z: Fp = z * y; + let z: Fp = z * x; + z + }, + criterion::BatchSize::SmallInput, + ); + }); + + group.bench_function("Multiplication in Fp9 with a conversion (4 muls)", |b| { + b.iter_batched( + || { + let x: Fp = rand::random(); + let y: Fp = rand::random(); + (x, y) + }, + |(x, y)| { + let x_fp9: Fp9 = From::from(x); + let y_fp9: Fp9 = From::from(y); + let z_fp9: Fp9 = x_fp9 * y_fp9; + let z_fp9: Fp9 = z_fp9 * x_fp9; + let z_fp9: Fp9 = z_fp9 * y_fp9; + let z_fp9: Fp9 = z_fp9 * x_fp9; + z_fp9 + }, + criterion::BatchSize::SmallInput, + ); + }); +} + +criterion_group!( + benches, + bench_poseidon_kimchi, + bench_conversions, + bench_basic_ops +); criterion_main!(benches); + +// sponge params for Fp9 + +fn fp9_sponge_params() -> ArithmeticSpongeParams { + let params = SpongeParametersKimchi::params(); + + // leverage .into() to convert from Fp to Fp9 + ArithmeticSpongeParams:: { + round_constants: params + .round_constants + .into_iter() + .map(|x| x.into_iter().map(Fp9::from).collect()) + .collect(), + mds: params + .mds + .into_iter() + .map(|x| x.into_iter().map(Fp9::from).collect()) + .collect(), + } +} + +fn fp9_static_params() -> &'static ArithmeticSpongeParams { + static PARAMS: Lazy> = Lazy::new(fp9_sponge_params); + &PARAMS +} diff --git a/poseidon/src/permutation.rs b/poseidon/src/permutation.rs index b3d58776a5..468e2e1f0f 100644 --- a/poseidon/src/permutation.rs +++ b/poseidon/src/permutation.rs @@ -1,12 +1,13 @@ //! The permutation module contains the function implementing the permutation used in Poseidon +use mina_curves::pasta::wasm_friendly::minimal_field::MinimalField; + use crate::{ constants::SpongeConstants, poseidon::{sbox, ArithmeticSpongeParams}, }; -use ark_ff::Field; -fn apply_mds_matrix( +fn apply_mds_matrix( params: &ArithmeticSpongeParams, state: &[F], ) -> Vec { @@ -30,7 +31,7 @@ fn apply_mds_matrix( } } -pub fn full_round( +pub fn full_round( params: &ArithmeticSpongeParams, state: &mut Vec, r: usize, @@ -44,7 +45,7 @@ pub fn full_round( } } -pub fn half_rounds( +pub fn half_rounds( params: &ArithmeticSpongeParams, state: &mut [F], ) { @@ -84,7 +85,7 @@ pub fn half_rounds( } } -pub fn poseidon_block_cipher( +pub fn poseidon_block_cipher( params: &ArithmeticSpongeParams, state: &mut Vec, ) { diff --git a/poseidon/src/poseidon.rs b/poseidon/src/poseidon.rs index 2ce1d1f3d3..7df61ae5fe 100644 --- a/poseidon/src/poseidon.rs +++ b/poseidon/src/poseidon.rs @@ -4,14 +4,14 @@ use crate::{ constants::SpongeConstants, permutation::{full_round, poseidon_block_cipher}, }; -use ark_ff::Field; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; +use mina_curves::pasta::wasm_friendly::MinimalField; use serde::{Deserialize, Serialize}; use serde_with::serde_as; /// Cryptographic sponge interface - for hashing an arbitrary amount of /// data into one or more field elements -pub trait Sponge { +pub trait Sponge { /// Create a new cryptographic sponge using arithmetic sponge `params` fn new(params: &'static ArithmeticSpongeParams) -> Self; @@ -25,7 +25,7 @@ pub trait Sponge { fn reset(&mut self); } -pub fn sbox(x: F) -> F { +pub fn sbox(x: F) -> F { x.pow([SC::PERM_SBOX as u64]) } @@ -37,7 +37,7 @@ pub enum SpongeState { #[serde_as] #[derive(Clone, Serialize, Deserialize, Default, Debug)] -pub struct ArithmeticSpongeParams { +pub struct ArithmeticSpongeParams { #[serde_as(as = "Vec>")] pub round_constants: Vec>, #[serde_as(as = "Vec>")] @@ -45,7 +45,7 @@ pub struct ArithmeticSpongeParams { +pub struct ArithmeticSponge { pub sponge_state: SpongeState, rate: usize, // TODO(mimoo: an array enforcing the width is better no? or at least an assert somewhere) @@ -54,7 +54,7 @@ pub struct ArithmeticSponge { pub constants: std::marker::PhantomData, } -impl ArithmeticSponge { +impl ArithmeticSponge { pub fn full_round(&mut self, r: usize) { full_round::(self.params, &mut self.state, r); } @@ -64,7 +64,7 @@ impl ArithmeticSponge { } } -impl Sponge for ArithmeticSponge { +impl Sponge for ArithmeticSponge { fn new(params: &'static ArithmeticSpongeParams) -> ArithmeticSponge { let capacity = SC::SPONGE_CAPACITY; let rate = SC::SPONGE_RATE;