Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mont 64 #1

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions math/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ harness = false
name = "polynom"
harness = false


Al-Kindi-0 marked this conversation as resolved.
Show resolved Hide resolved
[features]
concurrent = ["utils/concurrent", "std"]
default = ["std"]
Expand Down
2 changes: 1 addition & 1 deletion math/src/field/extensions/quadratic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ impl<B: ExtensibleField<2>> Deserializable for QuadExtension<B> {
#[cfg(test)]
mod tests {
use super::{DeserializationError, FieldElement, QuadExtension, Vec};
use crate::field::f128::BaseElement;
use crate::field::f64::BaseElement;
use rand_utils::rand_value;

// BASIC ALGEBRA
Expand Down
158 changes: 87 additions & 71 deletions math/src/field/f64/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

//! An implementation of a 64-bit STARK-friendly prime field with modulus $2^{64} - 2^{32} + 1$.
//! An implementation of a 64-bit STARK-friendly prime field with modulus $2^{64} - 2^{32} + 1$ using Montgomery representation.
Al-Kindi-0 marked this conversation as resolved.
Show resolved Hide resolved
//! Our implementation follows https://eprint.iacr.org/2022/274.pdf
//!
//! This field supports very fast modular arithmetic and has a number of other attractive
//! properties, including:
Expand Down Expand Up @@ -36,8 +37,11 @@ mod tests;
// Field modulus = 2^64 - 2^32 + 1
const M: u64 = 0xFFFFFFFF00000001;

// Epsilon = 2^32 - 1;
const E: u64 = 0xFFFFFFFF;
/// 2^128 mod M; this is used for conversion of elements into Montgomery representation.
const R2: u64 = 0xFFFFFFFE00000001;

// (p+1)/2
Al-Kindi-0 marked this conversation as resolved.
Show resolved Hide resolved
pub const MOD_2: u64 = (0xFFFFFFFF00000001 + 1u64) >> 1;

// 2^32 root of unity
const G: u64 = 1753635133440165772;
Expand All @@ -53,12 +57,36 @@ const ELEMENT_BYTES: usize = core::mem::size_of::<u64>();
/// Internal values are stored in the range [0, 2^64). The backing type is `u64`.
#[derive(Copy, Clone, Debug, Default)]
pub struct BaseElement(u64);

impl BaseElement {
/// Creates a new field element from the provided `value`. If the value is greater than or
/// equal to the field modulus, modular reduction is silently performed.
pub const fn new(value: u64) -> Self {
Self(value % M)
/// Creates a new field element from the provided `value`; the value is converted into
/// Montgomery representation.
pub const fn new(value: u64) -> BaseElement {
BaseElement(mont_red_cst((value as u128) * (R2 as u128)))
Al-Kindi-0 marked this conversation as resolved.
Show resolved Hide resolved
}
/// Gets the inner value that might not be canonical

Al-Kindi-0 marked this conversation as resolved.
Show resolved Hide resolved
pub const fn inner(self: &Self) -> u64 {
return self.0;
}

/// Multiple squarings in BaseField: return x^(2^n)
pub fn msquare(self, n: u32) -> Self {
let mut x = self;
for _ in 0..n {
x = x.square();
}
x
}
Al-Kindi-0 marked this conversation as resolved.
Show resolved Hide resolved

/// Test of equality between two BaseField elements; return value is
/// 0xFFFFFFFFFFFFFFFF if the two values are equal, or 0 otherwise.
#[inline(always)]
pub const fn equals(self, rhs: Self) -> u64 {
// Since internal representation is canonical, we can simply
// do a xor between the two operands, and then use the same
// expression as iszero().
let t = self.0 ^ rhs.0;
!((((t | t.wrapping_neg()) as i64) >> 63) as u64)
}
Al-Kindi-0 marked this conversation as resolved.
Show resolved Hide resolved
}

Expand All @@ -72,7 +100,7 @@ impl FieldElement for BaseElement {
const ELEMENT_BYTES: usize = ELEMENT_BYTES;
const IS_CANONICAL: bool = false;

#[inline]
#[inline(always)]
fn exp(self, power: Self::PositiveInteger) -> Self {
let mut b = self;

Expand Down Expand Up @@ -171,7 +199,6 @@ impl FieldElement for BaseElement {
unsafe { Vec::from_raw_parts(p as *mut Self, len, cap) }
}

#[inline]
fn as_base_elements(elements: &[Self]) -> &[Self::BaseField] {
elements
}
Expand Down Expand Up @@ -205,13 +232,7 @@ impl StarkField for BaseElement {

#[inline]
fn as_int(&self) -> Self::PositiveInteger {
// since the internal value of the element can be in [0, 2^64) range, we do an extra check
// here to convert it to the canonical form
if self.0 >= M {
self.0 - M
} else {
self.0
}
mont_red_cst(self.0 as u128)
}
}

Expand All @@ -235,9 +256,7 @@ impl Display for BaseElement {
impl PartialEq for BaseElement {
#[inline]
fn eq(&self, other: &Self) -> bool {
// since either of the elements can be in [0, 2^64) range, we first convert them to the
// canonical form to ensure that they are in [0, M) range and then compare them.
self.as_int() == other.as_int()
Self::equals(*self, *other) == 0xFFFFFFFFFFFFFFFF
}
}

Expand All @@ -249,11 +268,14 @@ impl Eq for BaseElement {}
impl Add for BaseElement {
type Output = Self;

/// Addition in BaseField
#[inline]
#[allow(clippy::suspicious_arithmetic_impl)]
fn add(self, rhs: Self) -> Self {
let (result, over) = self.0.overflowing_add(rhs.as_int());
Self(result.wrapping_sub(M * (over as u64)))
// We compute a + b = a - (p - b).
let (x1, c1) = self.0.overflowing_sub(M - rhs.0);
let adj = 0u32.wrapping_sub(c1 as u32);
BaseElement(x1.wrapping_sub(adj as u64))
}
}

Expand All @@ -270,8 +292,10 @@ impl Sub for BaseElement {
#[inline]
#[allow(clippy::suspicious_arithmetic_impl)]
fn sub(self, rhs: Self) -> Self {
let (result, under) = self.0.overflowing_sub(rhs.as_int());
Self(result.wrapping_add(M * (under as u64)))
// See reference above for more details.
Al-Kindi-0 marked this conversation as resolved.
Show resolved Hide resolved
let (x1, c1) = self.0.overflowing_sub(rhs.0);
let adj = 0u32.wrapping_sub(c1 as u32);
BaseElement(x1.wrapping_sub(adj as u64))
}
}

Expand All @@ -287,8 +311,7 @@ impl Mul for BaseElement {

#[inline]
Al-Kindi-0 marked this conversation as resolved.
Show resolved Hide resolved
fn mul(self, rhs: Self) -> Self {
let z = (self.0 as u128) * (rhs.0 as u128);
Self(mod_reduce(z))
BaseElement(mont_red_cst((self.0 as u128) * (rhs.0 as u128)))
}
}

Expand All @@ -305,7 +328,7 @@ impl Div for BaseElement {
#[inline]
#[allow(clippy::suspicious_arithmetic_impl)]
fn div(self, rhs: Self) -> Self {
self * rhs.inv()
self * Self::inv(rhs)
Al-Kindi-0 marked this conversation as resolved.
Show resolved Hide resolved
}
}
Al-Kindi-0 marked this conversation as resolved.
Show resolved Hide resolved

Expand All @@ -321,12 +344,7 @@ impl Neg for BaseElement {

#[inline]
fn neg(self) -> Self {
let v = self.as_int();
if v == 0 {
Self::ZERO
} else {
Self(M - v)
}
BaseElement::ZERO - self
}
}

Expand Down Expand Up @@ -421,39 +439,40 @@ impl ExtensibleField<3> for BaseElement {
// ================================================================================================

impl From<u128> for BaseElement {
/// Converts a 128-bit value into a field element. If the value is greater than or equal to
/// the field modulus, modular reduction is silently performed.
fn from(value: u128) -> Self {
Self(mod_reduce(value))
/// Converts a 128-bit value into a field element.
fn from(x: u128) -> Self {
//const R3: u128 = 1 (= 2^192 mod M );// thus we get that mont_red_var((mont_red_var(x) as u128) * R3) becomes
//BaseElement(mont_red_var(mont_red_var(x) as u128)) // Variable time implementation
BaseElement(mont_red_cst(mont_red_cst(x) as u128)) // Constant time implementation
}
}

impl From<u64> for BaseElement {
/// Converts a 64-bit value into a field element. If the value is greater than or equal to
/// the field modulus, modular reduction is silently performed.
fn from(value: u64) -> Self {
Self::new(value)
BaseElement::new(value)
}
}

impl From<u32> for BaseElement {
/// Converts a 32-bit value into a field element.
fn from(value: u32) -> Self {
Self::new(value as u64)
BaseElement::new(value as u64)
}
}

impl From<u16> for BaseElement {
/// Converts a 16-bit value into a field element.
fn from(value: u16) -> Self {
Self::new(value as u64)
BaseElement::new(value as u64)
}
}

impl From<u8> for BaseElement {
/// Converts an 8-bit value into a field element.
fn from(value: u8) -> Self {
Self::new(value as u64)
BaseElement::new(value as u64)
}
}

Expand All @@ -464,7 +483,7 @@ impl From<[u8; 8]> for BaseElement {
/// performed.
fn from(bytes: [u8; 8]) -> Self {
let value = u64::from_le_bytes(bytes);
Self::new(value)
BaseElement::new(value)
}
}

Expand Down Expand Up @@ -534,35 +553,6 @@ impl Deserializable for BaseElement {
}
}

// HELPER FUNCTIONS
// ================================================================================================

/// Reduces a 128-bit value by M such that the output is in [0, 2^64) range.
///
/// Adapted from: <https://github.com/mir-protocol/plonky2/blob/main/src/field/goldilocks_field.rs>
#[inline(always)]
fn mod_reduce(x: u128) -> u64 {
// assume x consists of four 32-bit values: a, b, c, d such that a contains 32 least
// significant bits and d contains 32 most significant bits. we break x into corresponding
// values as shown below
let ab = x as u64;
let cd = (x >> 64) as u64;
let c = (cd as u32) as u64;
let d = cd >> 32;

// compute ab - d; because d may be greater than ab we need to handle potential underflow
let (tmp0, under) = ab.overflowing_sub(d);
let tmp0 = tmp0.wrapping_sub(E * (under as u64));

// compute c * 2^32 - c; this is guaranteed not to underflow
let tmp1 = (c << 32) - c;

// add temp values and return the result; because each of the temp may be up to 64 bits,
// we need to handle potential overflow
let (result, over) = tmp0.overflowing_add(tmp1);
result.wrapping_add(E * (over as u64))
}

/// Squares the base N number of times and multiplies the result by the tail value.
#[inline(always)]
fn exp_acc<const N: usize>(base: BaseElement, tail: BaseElement) -> BaseElement {
Expand All @@ -572,3 +562,29 @@ fn exp_acc<const N: usize>(base: BaseElement, tail: BaseElement) -> BaseElement
}
result * tail
}
/// Montgomery reduction (variable time)
#[inline(always)]
pub fn mont_red_var(x: u128) -> u64 {
const NPRIME: u64 = 4294967297;
let q = (((x as u64) as u128) * (NPRIME as u128)) as u64;
let m = (q as u128) * (M as u128);
let y = (((x as i128).wrapping_sub(m as i128)) >> 64) as i64;
if x < m {
return (y + (M as i64)) as u64;
} else {
return y as u64;
};
}
/// Montgomery reduction (constant time)
#[inline(always)]
pub const fn mont_red_cst(x: u128) -> u64 {
// See reference above for a description of the following implementation.
let xl = x as u64;
let xh = (x >> 64) as u64;
let (a, e) = xl.overflowing_add(xl << 32);

let b = a.wrapping_sub(a >> 32).wrapping_sub(e as u64);

let (r, c) = xh.overflowing_sub(b);
r.wrapping_sub(0u32.wrapping_sub(c as u32) as u64)
}
18 changes: 10 additions & 8 deletions math/src/field/f64/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
// LICENSE file in the root directory of this source tree.

use super::{
AsBytes, BaseElement, DeserializationError, FieldElement, Serializable, StarkField, E, M,
//E,
Al-Kindi-0 marked this conversation as resolved.
Show resolved Hide resolved
AsBytes, BaseElement, DeserializationError, FieldElement, Serializable, StarkField, M,
};
use crate::field::{CubeExtension, ExtensionOf, QuadExtension};
use core::convert::TryFrom;
Expand Down Expand Up @@ -32,10 +33,10 @@ fn add() {
assert_eq!(BaseElement::ZERO, t + BaseElement::ONE);
assert_eq!(BaseElement::ONE, t + BaseElement::new(2));

// test non-canonical representation
let a = BaseElement::new(M - 1) + BaseElement::new(E);
let expected = ((((M - 1 + E) as u128) * 2) % (M as u128)) as u64;
assert_eq!(expected, (a + a).as_int());
//// test non-canonical representation
//let a = BaseElement::new(M - 1) + BaseElement::new(E);
//let expected = ((((M - 1 + E) as u128) * 2) % (M as u128)) as u64;
//assert_eq!(expected, (a + a).as_int());
Al-Kindi-0 marked this conversation as resolved.
Show resolved Hide resolved
}

#[test]
Expand Down Expand Up @@ -71,7 +72,7 @@ fn mul() {
assert_eq!(BaseElement::ZERO, r * BaseElement::ZERO);
assert_eq!(r, r * BaseElement::ONE);

// test multiplication within bounds
// test multifield::extensions::cubic::tests::bytes_as_elementsplication within bounds
Al-Kindi-0 marked this conversation as resolved.
Show resolved Hide resolved
assert_eq!(
BaseElement::from(15u8),
BaseElement::from(5u8) * BaseElement::from(3u8)
Expand Down Expand Up @@ -113,6 +114,7 @@ fn inv() {
assert_eq!(BaseElement::ZERO, BaseElement::inv(BaseElement::ZERO));
}


#[test]
fn element_as_int() {
let v = u64::MAX;
Expand All @@ -131,8 +133,8 @@ fn equals() {
assert_eq!(a.to_bytes(), b.to_bytes());

// but their internal representation is not
assert_ne!(a.0, b.0);
assert_ne!(a.as_bytes(), b.as_bytes());
//assert_ne!(a.0, b.0);
//assert_ne!(a.as_bytes(), b.as_bytes());
Al-Kindi-0 marked this conversation as resolved.
Show resolved Hide resolved
}

// ROOTS OF UNITY
Expand Down