Skip to content

Commit

Permalink
Update TFHE logproof bounds to use bits
Browse files Browse the repository at this point in the history
Also use u32 over u64, as std::lib does for various "bit length" operations like Shl, log2, et.
  • Loading branch information
samtay committed Feb 20, 2024
1 parent 1d2c5b6 commit e306358
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 61 deletions.
2 changes: 1 addition & 1 deletion logproof/benches/linear_relation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ where
// e_1 = q / 2p
// c_1 = s * a + e_1 + del * m
// c_2 = a
const BIT_SIZE: usize = 2 << 8;
const BIT_SIZE: u32 = 2 << 8;

println!("Generating data...");

Expand Down
10 changes: 5 additions & 5 deletions logproof/src/bfv_statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ use crate::{
};

/// In SEAL, `u` is sampled from a ternary distribution. The number of bits is 1.
const U_COEFFICIENT_BOUND: usize = 1;
const U_COEFFICIENT_BOUND: u32 = 1;
/// In SEAL, `e` is sampled from a centered binomial distribution with std dev 3.2, and a maximum
/// width multiplier of 6, so max bound is 19.2. 19.2.ceil_log2() == 5
const E_COEFFICIENT_BOUND: usize = 5;
const E_COEFFICIENT_BOUND: u32 = 5;
/// In SEAL, secret keys are sampled from a ternary distribution. The number of bits is 1.
const S_COEFFICIENT_BOUND: usize = 1;
const S_COEFFICIENT_BOUND: u32 = 1;

/// A proof statement verifying that a ciphertext is an encryption of a known plaintext message.
/// Note that these statements are per SEAL plain/ciphertexts, where Sunscreen encodings are at a
Expand Down Expand Up @@ -440,7 +440,7 @@ where
let degree = params.degree() as usize;

// calculate bounds
let m_default_bound = Bounds(vec![params.plain_modulus().ceil_log2() as usize; degree]);
let m_default_bound = Bounds(vec![params.plain_modulus().ceil_log2(); degree]);
let r_bound = m_default_bound.clone();
let u_bound = Bounds(vec![U_COEFFICIENT_BOUND; degree]);
let e_bound = Bounds(vec![E_COEFFICIENT_BOUND; degree]);
Expand All @@ -449,7 +449,7 @@ where
let q_div_2_bits = calculate_ciphertext_modulus(params.ciphertext_modulus())
.div(NonZero::from_uint(Uint::from(2u8)))
.ceil_log2();
let decrypt_e_bound = Bounds(vec![q_div_2_bits as usize; degree]);
let decrypt_e_bound = Bounds(vec![q_div_2_bits as u32; degree]);

// insert them
for i in 0..IdxOffsets::num_messages(statements) {
Expand Down
72 changes: 36 additions & 36 deletions logproof/src/linear_relation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ type MatrixPoly<Q> = Matrix<Polynomial<Q>>;
* Bounds on the coefficients in the secret S (specified in number of bits).
*/
#[derive(Clone, Debug, PartialEq)]
pub struct Bounds(pub Vec<usize>);
pub struct Bounds(pub Vec<u32>);

impl Zero for Bounds {
// The empty vector could be seen as no bounds. Also follows the field
Expand All @@ -69,7 +69,7 @@ impl Zero for Bounds {
}

impl std::ops::Deref for Bounds {
type Target = [usize];
type Target = [u32];

fn deref(&self) -> &Self::Target {
&self.0
Expand Down Expand Up @@ -132,22 +132,22 @@ where
/**
* The number of rows in a.
*/
pub fn n(&self) -> u64 {
self.a.rows as u64
pub fn n(&self) -> u32 {
self.a.rows as u32
}

/**
* The number of cols in a and the number rows in s.
*/
pub fn m(&self) -> u64 {
self.a.cols as u64
pub fn m(&self) -> u32 {
self.a.cols as u32
}

/**
* The number of cols in t.
*/
pub fn k(&self) -> u64 {
self.t.cols as u64
pub fn k(&self) -> u32 {
self.t.cols as u32
}

/**
Expand All @@ -168,11 +168,11 @@ where
/**
* Sum of all the bounds
*/
pub fn b_sum(&self) -> u64 {
pub fn b_sum(&self) -> u32 {
self.b()
.as_slice()
.iter()
.map(|v| v.iter().sum::<usize>() as u64)
.map(|v| v.iter().sum::<u32>())
.sum()
}

Expand All @@ -189,7 +189,7 @@ where
let mut last_end_range = 0;

for (k, b_piece) in b.as_slice().iter().enumerate() {
let bits = b_piece.iter().sum::<usize>();
let bits = b_piece.iter().sum::<u32>() as usize;

// Get the orginal matrix index
let i = k / self.bounds.cols;
Expand All @@ -210,28 +210,28 @@ where
/**
* The degree of `f`.
*/
pub fn d(&self) -> u64 {
self.f.vartime_degree() as u64
pub fn d(&self) -> u32 {
self.f.vartime_degree() as u32
}

/**
* Number of coefficients in secret vector s
*/
pub fn number_coeff_in_s(&self) -> u64 {
pub fn number_coeff_in_s(&self) -> u32 {
self.m() * self.d()
}

/**
* Computes the nk(d-1)b_2 term in l.
*/
pub fn nk_d_min_1_b_2(&self) -> u64 {
pub fn nk_d_min_1_b_2(&self) -> u32 {
self.n() * self.k() * (self.d() - 1) * self.b_2()
}

/**
* Computes the nk(2d-1)b_1 term in l.
*/
pub fn nk_2d_min_1_b_1(&self) -> u64 {
pub fn nk_2d_min_1_b_1(&self) -> u32 {
self.n() * self.k() * (2 * self.d() - 1) * self.b_1()
}

Expand All @@ -245,7 +245,7 @@ where
for r in 0..self.bounds.rows {
column_bound_sum += self.bounds[(r, c)]
.iter()
.map(|b| if *b > 0 { 2u64.pow(*b as u32) } else { 0 })
.map(|b| if *b > 0 { 2u64.pow(*b) } else { 0 })
.sum::<u64>();
}
column_bound_sum
Expand All @@ -256,7 +256,7 @@ where
/**
* The number of bits needed to store the elements of R1.
*/
pub fn b_1(&self) -> u64 {
pub fn b_1(&self) -> u32 {
let d_big = ZqRistretto::from(self.d());
let max_bounds_column_sum = ZqRistretto::from(self.max_bounds_column_sum());

Expand All @@ -270,14 +270,14 @@ where
/**
* The number of bits needed to store values in `Fp<Q>`.
*/
pub fn b_2(&self) -> u64 {
pub fn b_2(&self) -> u32 {
Log2::ceil_log2(&Q::field_modulus())
}

/**
* The length in bits of the binary expansion of the serialized secret * vectors.
*/
pub fn l(&self) -> u64 {
pub fn l(&self) -> u32 {
let total_bounds_all_equations = self.b_sum();
let nk = self.n().checked_mul(self.k()).unwrap();

Expand Down Expand Up @@ -496,19 +496,19 @@ impl LogProof {
let r_1_serialized = Self::serialize(&r_1, (2 * d - 1) as usize);
let r_2_serialized = Self::serialize(&r_2, (d - 1) as usize);

assert_eq!(s_serialized.len() as u64, m * k * d);
assert_eq!(s_serialized.len() as u32, m * k * d);

assert_eq!(r_1_serialized.len() as u64, n * k * (2 * d - 1));
assert_eq!(r_2_serialized.len() as u64, n * k * (d - 1));
assert_eq!(r_1_serialized.len() as u32, n * k * (2 * d - 1));
assert_eq!(r_2_serialized.len() as u32, n * k * (d - 1));

let s_binary: BitVec = Self::to_2s_complement_multibound(&s_serialized, &b_serialized);
assert_eq!(s_binary.len() as u64, total_bounds_all_equations);
assert_eq!(s_binary.len() as u32, total_bounds_all_equations);

let r_1_binary = Self::to_2s_complement(&r_1_serialized, b_1);
assert_eq!(r_1_binary.len() as u64, n * k * (2 * d - 1) * b_1);
assert_eq!(r_1_binary.len() as u32, n * k * (2 * d - 1) * b_1);

let r_2_binary = Self::to_2s_complement(&r_2_serialized, b_2);
assert_eq!(r_2_binary.len() as u64, n * k * (d - 1) * b_2);
assert_eq!(r_2_binary.len() as u32, n * k * (d - 1) * b_2);

let mut s_1 = s_binary.clone();
s_1.extend(r_1_binary.iter());
Expand Down Expand Up @@ -903,7 +903,7 @@ impl LogProof {
two_b.as_slice(),
);

assert_eq!(term_1.len() as u64, b_sum);
assert_eq!(term_1.len() as u32, b_sum);

// Compute term 2
let q = ZqRistretto::try_from(Q::field_modulus()).unwrap();
Expand All @@ -919,7 +919,7 @@ impl LogProof {
.tensor(alpha_2d_minus_1)
.tensor(two_b_1);

assert_eq!(term_2.len() as u64, b_1 * (2 * d - 1) * n * k);
assert_eq!(term_2.len() as u32, b_1 * (2 * d - 1) * n * k);

// Compute term 3
let d_min_1 = d as usize - 1;
Expand All @@ -935,7 +935,7 @@ impl LogProof {
.tensor(alpha_d_minus_1)
.tensor(two_b_2);

assert_eq!(term_3.len() as u64, b_2 * (d - 1) * n * k);
assert_eq!(term_3.len() as u32, b_2 * (d - 1) * n * k);

let mut result = vec![];

Expand Down Expand Up @@ -1121,7 +1121,7 @@ impl LogProof {
* panic on any other input.
*
*/
fn to_2s_complement_single<B, const N: usize>(value: &Zq<N, B>, log_b: u64, bitvec: &mut BitVec)
fn to_2s_complement_single<B, const N: usize>(value: &Zq<N, B>, log_b: u32, bitvec: &mut BitVec)
where
B: ArithmeticBackend<N>,
{
Expand Down Expand Up @@ -1168,7 +1168,7 @@ impl LogProof {
* `value` is the element in Zq and `b` is the number of bits needed
* to represent the signed value.
*/
fn to_2s_complement<B, const N: usize>(values: &[Zq<N, B>], log_b: u64) -> BitVec
fn to_2s_complement<B, const N: usize>(values: &[Zq<N, B>], log_b: u32) -> BitVec
where
B: ArithmeticBackend<N>,
{
Expand All @@ -1192,14 +1192,14 @@ impl LogProof {
* `value` is the element in Zq and `b` is the number of bits needed
* to represent the signed value.
*/
fn to_2s_complement_multibound<B, const N: usize>(values: &[Zq<N, B>], log_b: &[u64]) -> BitVec
fn to_2s_complement_multibound<B, const N: usize>(values: &[Zq<N, B>], log_b: &[u32]) -> BitVec
where
B: ArithmeticBackend<N>,
{
// Make sure we have an equal number of values and bounds to serialize
assert_eq!(values.len(), log_b.len());

let mut bitvec = BitVec::with_capacity(log_b.iter().sum::<u64>() as usize);
let mut bitvec = BitVec::with_capacity(log_b.iter().sum::<u32>() as usize);

// This code should not feature timing side-channels.
for (value, bound) in zip(values.iter(), log_b.iter()) {
Expand All @@ -1216,11 +1216,11 @@ impl LogProof {
* The matrix is serialized in row-major order, with bound
* coefficients being contiguous.
*/
pub fn serialize_bounds(bounds: &Matrix<Bounds>) -> Vec<u64> {
pub fn serialize_bounds(bounds: &Matrix<Bounds>) -> Vec<u32> {
bounds
.as_slice()
.iter()
.flat_map(|x| x.0.iter().map(|x| *x as u64))
.flat_map(|x| x.0.iter().copied())
.collect()
}

Expand Down Expand Up @@ -1403,7 +1403,7 @@ mod test {
0
} else {
next_higher_power_of_two(x.unsigned_abs()).ilog2()
as usize
as u32
}
})
.collect::<Vec<_>>(),
Expand Down
26 changes: 14 additions & 12 deletions logproof/src/math.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{borrow::Borrow, ops::Mul};

use crypto_bigint::Uint;
use crypto_bigint::{Limb, Uint};
use curve25519_dalek::{ristretto::RistrettoPoint, scalar::Scalar, traits::VartimeMultiscalarMul};
use rand::Rng;
use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
Expand Down Expand Up @@ -139,22 +139,22 @@ pub trait Log2 {
* # Panics
* When the given value is zero.
*/
fn log2(&self) -> u64;
fn log2(&self) -> u32;

/**
* Compute the ceiling of the log2 of the given value.
*
* # Panics
* When the given value is zero.
*/
fn ceil_log2(&self) -> u64;
fn ceil_log2(&self) -> u32;
}

impl Log2 for u64 {
/**
* An implementation of log2 that works on stable.
*/
fn log2(&self) -> u64 {
fn log2(&self) -> u32 {
let mut mask = 0x8000_0000_0000_0000;

for i in 0..64 {
Expand All @@ -168,7 +168,7 @@ impl Log2 for u64 {
panic!("Value was zero.");
}

fn ceil_log2(&self) -> u64 {
fn ceil_log2(&self) -> u32 {
let ceil_factor = if self.is_power_of_two() { 0 } else { 1 };
self.log2() + ceil_factor
}
Expand All @@ -179,7 +179,9 @@ fn is_power_of_two_bigint<const N: usize>(b: &Uint<N>) -> bool {
}

impl<const N: usize> Log2 for Uint<N> {
fn log2(&self) -> u64 {
fn log2(&self) -> u32 {
let limb_bits = Limb::BITS as u32;

for i in 0..self.as_limbs().len() {
let i = self.as_limbs().len() - i - 1;
let limb = self.as_limbs()[i];
Expand All @@ -188,25 +190,25 @@ impl<const N: usize> Log2 for Uint<N> {
continue;
}

return Log2::log2(&limb.0) + (i as u64) * 64;
return Log2::log2(&limb.0) + (i as u32) * limb_bits;
}

panic!("Value was zero.");
}

fn ceil_log2(&self) -> u64 {
fn ceil_log2(&self) -> u32 {
let ceil_factor = if is_power_of_two_bigint(self) { 0 } else { 1 };

self.log2() + ceil_factor
}
}

impl<const N: usize, B: ArithmeticBackend<N>> Log2 for Zq<N, B> {
fn log2(&self) -> u64 {
fn log2(&self) -> u32 {
Uint::<N>::log2(&self.val)
}

fn ceil_log2(&self) -> u64 {
fn ceil_log2(&self) -> u32 {
Uint::<N>::ceil_log2(&self.val)
}
}
Expand Down Expand Up @@ -663,7 +665,7 @@ mod test {

for value in options {
let f = value as f64;
let expected = f.log2().ceil() as u64;
let expected = f.log2().ceil() as u32;

let calculated = Log2::ceil_log2(&value);

Expand All @@ -677,7 +679,7 @@ mod test {

for value in options {
let f = value as f64;
let expected = f.log2().ceil() as u64;
let expected = f.log2().ceil() as u32;

let b: Uint<1> = Uint::from(value);
let calculated = Log2::ceil_log2(&b);
Expand Down
Loading

0 comments on commit e306358

Please sign in to comment.