Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mont 64 #1

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions math/src/field/extensions/cubic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ impl<B: ExtensibleField<3>> Deserializable for CubeExtension<B> {

#[cfg(test)]
mod tests {
use super::{CubeExtension, DeserializationError, FieldElement, Vec};
use super::{CubeExtension, DeserializationError, FieldElement};
use crate::field::f64::BaseElement;
use rand_utils::rand_value;

Expand Down Expand Up @@ -400,10 +400,13 @@ mod tests {
),
];

let expected: Vec<u8> = vec![
1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0,
0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0,
];
let mut expected = vec![];
expected.extend_from_slice(&source[0].0.inner().to_le_bytes());
expected.extend_from_slice(&source[0].1.inner().to_le_bytes());
expected.extend_from_slice(&source[0].2.inner().to_le_bytes());
expected.extend_from_slice(&source[1].0.inner().to_le_bytes());
expected.extend_from_slice(&source[1].1.inner().to_le_bytes());
expected.extend_from_slice(&source[1].2.inner().to_le_bytes());

assert_eq!(
expected,
Expand All @@ -413,12 +416,7 @@ mod tests {

#[test]
fn bytes_as_elements() {
let bytes: Vec<u8> = vec![
1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0,
0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7,
];

let expected = vec![
let elements = vec![
CubeExtension(
BaseElement::new(1),
BaseElement::new(2),
Expand All @@ -431,9 +429,18 @@ mod tests {
),
];

let mut bytes = vec![];
bytes.extend_from_slice(&elements[0].0.inner().to_le_bytes());
bytes.extend_from_slice(&elements[0].1.inner().to_le_bytes());
bytes.extend_from_slice(&elements[0].2.inner().to_le_bytes());
bytes.extend_from_slice(&elements[1].0.inner().to_le_bytes());
bytes.extend_from_slice(&elements[1].1.inner().to_le_bytes());
bytes.extend_from_slice(&elements[1].2.inner().to_le_bytes());
bytes.extend_from_slice(&BaseElement::new(5).inner().to_le_bytes());

let result = unsafe { CubeExtension::<BaseElement>::bytes_as_elements(&bytes[..48]) };
assert!(result.is_ok());
assert_eq!(expected, result.unwrap());
assert_eq!(elements, result.unwrap());

let result = unsafe { CubeExtension::<BaseElement>::bytes_as_elements(&bytes) };
assert!(matches!(result, Err(DeserializationError::InvalidValue(_))));
Expand Down
33 changes: 17 additions & 16 deletions math/src/field/extensions/quadratic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,8 @@ impl<B: ExtensibleField<2>> Deserializable for QuadExtension<B> {

#[cfg(test)]
mod tests {
use super::{DeserializationError, FieldElement, QuadExtension, Vec};
use crate::field::f128::BaseElement;
use super::{DeserializationError, FieldElement, QuadExtension};
use crate::field::f64::BaseElement;
use rand_utils::rand_value;

// BASIC ALGEBRA
Expand Down Expand Up @@ -381,11 +381,11 @@ mod tests {
QuadExtension(BaseElement::new(3), BaseElement::new(4)),
];

let expected: Vec<u8> = vec![
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0,
];
let mut expected = vec![];
expected.extend_from_slice(&source[0].0.inner().to_le_bytes());
expected.extend_from_slice(&source[0].1.inner().to_le_bytes());
expected.extend_from_slice(&source[1].0.inner().to_le_bytes());
expected.extend_from_slice(&source[1].1.inner().to_le_bytes());

assert_eq!(
expected,
Expand All @@ -395,20 +395,21 @@ mod tests {

#[test]
fn bytes_as_elements() {
let bytes: Vec<u8> = vec![
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 5,
];

let expected = vec![
let elements = vec![
QuadExtension(BaseElement::new(1), BaseElement::new(2)),
QuadExtension(BaseElement::new(3), BaseElement::new(4)),
];

let result = unsafe { QuadExtension::<BaseElement>::bytes_as_elements(&bytes[..64]) };
let mut bytes = vec![];
bytes.extend_from_slice(&elements[0].0.inner().to_le_bytes());
bytes.extend_from_slice(&elements[0].1.inner().to_le_bytes());
bytes.extend_from_slice(&elements[1].0.inner().to_le_bytes());
bytes.extend_from_slice(&elements[1].1.inner().to_le_bytes());
bytes.extend_from_slice(&BaseElement::new(5).inner().to_le_bytes());

let result = unsafe { QuadExtension::<BaseElement>::bytes_as_elements(&bytes[..32]) };
assert!(result.is_ok());
assert_eq!(expected, result.unwrap());
assert_eq!(elements, result.unwrap());

let result = unsafe { QuadExtension::<BaseElement>::bytes_as_elements(&bytes) };
assert!(matches!(result, Err(DeserializationError::InvalidValue(_))));
Expand Down
153 changes: 77 additions & 76 deletions math/src/field/f64/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

//! An implementation of a 64-bit STARK-friendly prime field with modulus $2^{64} - 2^{32} + 1$.
//! An implementation of a 64-bit STARK-friendly prime field with modulus $2^{64} - 2^{32} + 1$
//! using Montgomery representation.
//! Our implementation follows https://eprint.iacr.org/2022/274.pdf and is constant-time except
//! for the mont_red_var() function.
//!
//! This field supports very fast modular arithmetic and has a number of other attractive
//! properties, including:
Expand Down Expand Up @@ -33,13 +36,13 @@ mod tests;
// CONSTANTS
// ================================================================================================

// Field modulus = 2^64 - 2^32 + 1
/// Field modulus = 2^64 - 2^32 + 1
const M: u64 = 0xFFFFFFFF00000001;

// Epsilon = 2^32 - 1;
const E: u64 = 0xFFFFFFFF;
/// 2^128 mod M; this is used for conversion of elements into Montgomery representation.
const R2: u64 = 0xFFFFFFFE00000001;

// 2^32 root of unity
/// 2^32 root of unity
const G: u64 = 1753635133440165772;

/// Number of bytes needed to represent field element
Expand All @@ -48,17 +51,21 @@ const ELEMENT_BYTES: usize = core::mem::size_of::<u64>();
// FIELD ELEMENT
// ================================================================================================

/// Represents base field element in the field.
/// Represents base field element in the field
///
/// Internal values are stored in the range [0, 2^64). The backing type is `u64`.
#[derive(Copy, Clone, Debug, Default)]
pub struct BaseElement(u64);

impl BaseElement {
/// Creates a new field element from the provided `value`. If the value is greater than or
/// equal to the field modulus, modular reduction is silently performed.
pub const fn new(value: u64) -> Self {
Self(value % M)
/// Creates a new field element from the provided `value`; the value is converted into
/// Montgomery representation.
pub const fn new(value: u64) -> BaseElement {
Self(mont_red_cst((value as u128) * (R2 as u128)))
}

/// Returns the non-canonical u64 inner value.
pub const fn inner(&self) -> u64{
self.0
}
}

Expand All @@ -72,7 +79,7 @@ impl FieldElement for BaseElement {
const ELEMENT_BYTES: usize = ELEMENT_BYTES;
const IS_CANONICAL: bool = false;

#[inline]
#[inline(always)]
fn exp(self, power: Self::PositiveInteger) -> Self {
let mut b = self;

Expand Down Expand Up @@ -171,7 +178,6 @@ impl FieldElement for BaseElement {
unsafe { Vec::from_raw_parts(p as *mut Self, len, cap) }
}

#[inline]
fn as_base_elements(elements: &[Self]) -> &[Self::BaseField] {
elements
}
Expand Down Expand Up @@ -205,13 +211,7 @@ impl StarkField for BaseElement {

#[inline]
fn as_int(&self) -> Self::PositiveInteger {
// since the internal value of the element can be in [0, 2^64) range, we do an extra check
// here to convert it to the canonical form
if self.0 >= M {
self.0 - M
} else {
self.0
}
mont_red_cst(self.0 as u128)
}
}

Expand All @@ -235,9 +235,7 @@ impl Display for BaseElement {
impl PartialEq for BaseElement {
#[inline]
fn eq(&self, other: &Self) -> bool {
// since either of the elements can be in [0, 2^64) range, we first convert them to the
// canonical form to ensure that they are in [0, M) range and then compare them.
self.as_int() == other.as_int()
equals(self.0, other.0) == 0xFFFFFFFFFFFFFFFF
}
}

Expand All @@ -249,11 +247,14 @@ impl Eq for BaseElement {}
impl Add for BaseElement {
type Output = Self;

/// Addition in BaseField
#[inline]
#[allow(clippy::suspicious_arithmetic_impl)]
fn add(self, rhs: Self) -> Self {
let (result, over) = self.0.overflowing_add(rhs.as_int());
Self(result.wrapping_sub(M * (over as u64)))
// We compute a + b = a - (p - b).
let (x1, c1) = self.0.overflowing_sub(M - rhs.0);
let adj = 0u32.wrapping_sub(c1 as u32);
Self(x1.wrapping_sub(adj as u64))
}
}

Expand All @@ -270,8 +271,9 @@ impl Sub for BaseElement {
#[inline]
#[allow(clippy::suspicious_arithmetic_impl)]
fn sub(self, rhs: Self) -> Self {
let (result, under) = self.0.overflowing_sub(rhs.as_int());
Self(result.wrapping_add(M * (under as u64)))
let (x1, c1) = self.0.overflowing_sub(rhs.0);
let adj = 0u32.wrapping_sub(c1 as u32);
Self(x1.wrapping_sub(adj as u64))
}
}

Expand All @@ -287,8 +289,7 @@ impl Mul for BaseElement {

#[inline]
Al-Kindi-0 marked this conversation as resolved.
Show resolved Hide resolved
fn mul(self, rhs: Self) -> Self {
let z = (self.0 as u128) * (rhs.0 as u128);
Self(mod_reduce(z))
Self(mont_red_cst((self.0 as u128) * (rhs.0 as u128)))
}
}

Expand Down Expand Up @@ -321,12 +322,7 @@ impl Neg for BaseElement {

#[inline]
fn neg(self) -> Self {
let v = self.as_int();
if v == 0 {
Self::ZERO
} else {
Self(M - v)
}
Self::ZERO - self
}
}

Expand Down Expand Up @@ -407,12 +403,9 @@ impl ExtensibleField<3> for BaseElement {
fn frobenius(x: [Self; 3]) -> [Self; 3] {
// coefficients were computed using SageMath
[
x[0] + BaseElement::new(10615703402128488253) * x[1]
+ BaseElement::new(6700183068485440220) * x[2],
BaseElement::new(10050274602728160328) * x[1]
+ BaseElement::new(14531223735771536287) * x[2],
BaseElement::new(11746561000929144102) * x[1]
+ BaseElement::new(8396469466686423992) * x[2],
x[0] + Self::new(10615703402128488253) * x[1] + Self::new(6700183068485440220) * x[2],
Self::new(10050274602728160328) * x[1] + Self::new(14531223735771536287) * x[2],
Self::new(11746561000929144102) * x[1] + Self::new(8396469466686423992) * x[2],
]
}
}
Expand All @@ -421,10 +414,11 @@ impl ExtensibleField<3> for BaseElement {
// ================================================================================================

impl From<u128> for BaseElement {
/// Converts a 128-bit value into a field element. If the value is greater than or equal to
/// the field modulus, modular reduction is silently performed.
fn from(value: u128) -> Self {
Self(mod_reduce(value))
/// Converts a 128-bit value into a field element.
fn from(x: u128) -> Self {
//const R3: u128 = 1 (= 2^192 mod M );// thus we get that mont_red_var((mont_red_var(x) as u128) * R3) becomes
//Self(mont_red_var(mont_red_var(x) as u128)) // Variable time implementation
Self(mont_red_cst(mont_red_cst(x) as u128)) // Constant time implementation
}
}

Expand Down Expand Up @@ -499,7 +493,7 @@ impl<'a> TryFrom<&'a [u8]> for BaseElement {
value
)));
}
Ok(BaseElement::new(value))
Ok(Self::new(value))
}
}

Expand Down Expand Up @@ -530,39 +524,10 @@ impl Deserializable for BaseElement {
value
)));
}
Ok(BaseElement::new(value))
Ok(Self::new(value))
}
}

// HELPER FUNCTIONS
// ================================================================================================

/// Reduces a 128-bit value by M such that the output is in [0, 2^64) range.
///
/// Adapted from: <https://github.com/mir-protocol/plonky2/blob/main/src/field/goldilocks_field.rs>
#[inline(always)]
fn mod_reduce(x: u128) -> u64 {
// assume x consists of four 32-bit values: a, b, c, d such that a contains 32 least
// significant bits and d contains 32 most significant bits. we break x into corresponding
// values as shown below
let ab = x as u64;
let cd = (x >> 64) as u64;
let c = (cd as u32) as u64;
let d = cd >> 32;

// compute ab - d; because d may be greater than ab we need to handle potential underflow
let (tmp0, under) = ab.overflowing_sub(d);
let tmp0 = tmp0.wrapping_sub(E * (under as u64));

// compute c * 2^32 - c; this is guaranteed not to underflow
let tmp1 = (c << 32) - c;

// add temp values and return the result; because each of the temp may be up to 64 bits,
// we need to handle potential overflow
let (result, over) = tmp0.overflowing_add(tmp1);
result.wrapping_add(E * (over as u64))
}

/// Squares the base N number of times and multiplies the result by the tail value.
#[inline(always)]
fn exp_acc<const N: usize>(base: BaseElement, tail: BaseElement) -> BaseElement {
Expand All @@ -572,3 +537,39 @@ fn exp_acc<const N: usize>(base: BaseElement, tail: BaseElement) -> BaseElement
}
result * tail
}

/// Montgomery reduction (variable time)
#[inline(always)]
pub const fn mont_red_var(x: u128) -> u64 {
const NPRIME: u64 = 4294967297;
let q = (((x as u64) as u128) * (NPRIME as u128)) as u64;
let m = (q as u128) * (M as u128);
let y = (((x as i128).wrapping_sub(m as i128)) >> 64) as i64;
if x < m {
return (y + (M as i64)) as u64;
} else {
return y as u64;
};
}

/// Montgomery reduction (constant time)
#[inline(always)]
pub const fn mont_red_cst(x: u128) -> u64 {
// See reference above for a description of the following implementation.
let xl = x as u64;
let xh = (x >> 64) as u64;
let (a, e) = xl.overflowing_add(xl << 32);

let b = a.wrapping_sub(a >> 32).wrapping_sub(e as u64);

let (r, c) = xh.overflowing_sub(b);
r.wrapping_sub(0u32.wrapping_sub(c as u32) as u64)
}

/// Test of equality between two BaseField elements; return value is
/// 0xFFFFFFFFFFFFFFFF if the two values are equal, or 0 otherwise.
#[inline(always)]
pub fn equals(lhs: u64, rhs: u64) -> u64 {
let t = lhs ^ rhs;
!((((t | t.wrapping_neg()) as i64) >> 63) as u64)
}
Loading