From e306358b8192169682eeeae69335e28e538b4992 Mon Sep 17 00:00:00 2001 From: Sam Tay Date: Tue, 20 Feb 2024 13:27:40 -0500 Subject: [PATCH] Update TFHE logproof bounds to use bits Also use u32 over u64, as std::lib does for various "bit length" operations like Shl, log2, et. --- logproof/benches/linear_relation.rs | 2 +- logproof/src/bfv_statement.rs | 10 ++-- logproof/src/linear_relation.rs | 72 ++++++++++++++--------------- logproof/src/math.rs | 26 ++++++----- sunscreen_runtime/src/builder.rs | 2 +- sunscreen_runtime/src/linked.rs | 2 +- sunscreen_tfhe/src/zkp.rs | 10 ++-- 7 files changed, 63 insertions(+), 61 deletions(-) diff --git a/logproof/benches/linear_relation.rs b/logproof/benches/linear_relation.rs index fc2bc3301..b11e67b82 100644 --- a/logproof/benches/linear_relation.rs +++ b/logproof/benches/linear_relation.rs @@ -76,7 +76,7 @@ where // e_1 = q / 2p // c_1 = s * a + e_1 + del * m // c_2 = a - const BIT_SIZE: usize = 2 << 8; + const BIT_SIZE: u32 = 2 << 8; println!("Generating data..."); diff --git a/logproof/src/bfv_statement.rs b/logproof/src/bfv_statement.rs index 490714108..5173da677 100644 --- a/logproof/src/bfv_statement.rs +++ b/logproof/src/bfv_statement.rs @@ -26,12 +26,12 @@ use crate::{ }; /// In SEAL, `u` is sampled from a ternary distribution. The number of bits is 1. -const U_COEFFICIENT_BOUND: usize = 1; +const U_COEFFICIENT_BOUND: u32 = 1; /// In SEAL, `e` is sampled from a centered binomial distribution with std dev 3.2, and a maximum /// width multiplier of 6, so max bound is 19.2. 19.2.ceil_log2() == 5 -const E_COEFFICIENT_BOUND: usize = 5; +const E_COEFFICIENT_BOUND: u32 = 5; /// In SEAL, secret keys are sampled from a ternary distribution. The number of bits is 1. -const S_COEFFICIENT_BOUND: usize = 1; +const S_COEFFICIENT_BOUND: u32 = 1; /// A proof statement verifying that a ciphertext is an encryption of a known plaintext message. /// Note that these statements are per SEAL plain/ciphertexts, where Sunscreen encodings are at a @@ -440,7 +440,7 @@ where let degree = params.degree() as usize; // calculate bounds - let m_default_bound = Bounds(vec![params.plain_modulus().ceil_log2() as usize; degree]); + let m_default_bound = Bounds(vec![params.plain_modulus().ceil_log2(); degree]); let r_bound = m_default_bound.clone(); let u_bound = Bounds(vec![U_COEFFICIENT_BOUND; degree]); let e_bound = Bounds(vec![E_COEFFICIENT_BOUND; degree]); @@ -449,7 +449,7 @@ where let q_div_2_bits = calculate_ciphertext_modulus(params.ciphertext_modulus()) .div(NonZero::from_uint(Uint::from(2u8))) .ceil_log2(); - let decrypt_e_bound = Bounds(vec![q_div_2_bits as usize; degree]); + let decrypt_e_bound = Bounds(vec![q_div_2_bits as u32; degree]); // insert them for i in 0..IdxOffsets::num_messages(statements) { diff --git a/logproof/src/linear_relation.rs b/logproof/src/linear_relation.rs index 3e1531277..e12524610 100644 --- a/logproof/src/linear_relation.rs +++ b/logproof/src/linear_relation.rs @@ -53,7 +53,7 @@ type MatrixPoly = Matrix>; * Bounds on the coefficients in the secret S (specified in number of bits). */ #[derive(Clone, Debug, PartialEq)] -pub struct Bounds(pub Vec); +pub struct Bounds(pub Vec); impl Zero for Bounds { // The empty vector could be seen as no bounds. Also follows the field @@ -69,7 +69,7 @@ impl Zero for Bounds { } impl std::ops::Deref for Bounds { - type Target = [usize]; + type Target = [u32]; fn deref(&self) -> &Self::Target { &self.0 @@ -132,22 +132,22 @@ where /** * The number of rows in a. */ - pub fn n(&self) -> u64 { - self.a.rows as u64 + pub fn n(&self) -> u32 { + self.a.rows as u32 } /** * The number of cols in a and the number rows in s. */ - pub fn m(&self) -> u64 { - self.a.cols as u64 + pub fn m(&self) -> u32 { + self.a.cols as u32 } /** * The number of cols in t. */ - pub fn k(&self) -> u64 { - self.t.cols as u64 + pub fn k(&self) -> u32 { + self.t.cols as u32 } /** @@ -168,11 +168,11 @@ where /** * Sum of all the bounds */ - pub fn b_sum(&self) -> u64 { + pub fn b_sum(&self) -> u32 { self.b() .as_slice() .iter() - .map(|v| v.iter().sum::() as u64) + .map(|v| v.iter().sum::()) .sum() } @@ -189,7 +189,7 @@ where let mut last_end_range = 0; for (k, b_piece) in b.as_slice().iter().enumerate() { - let bits = b_piece.iter().sum::(); + let bits = b_piece.iter().sum::() as usize; // Get the orginal matrix index let i = k / self.bounds.cols; @@ -210,28 +210,28 @@ where /** * The degree of `f`. */ - pub fn d(&self) -> u64 { - self.f.vartime_degree() as u64 + pub fn d(&self) -> u32 { + self.f.vartime_degree() as u32 } /** * Number of coefficients in secret vector s */ - pub fn number_coeff_in_s(&self) -> u64 { + pub fn number_coeff_in_s(&self) -> u32 { self.m() * self.d() } /** * Computes the nk(d-1)b_2 term in l. */ - pub fn nk_d_min_1_b_2(&self) -> u64 { + pub fn nk_d_min_1_b_2(&self) -> u32 { self.n() * self.k() * (self.d() - 1) * self.b_2() } /** * Computes the nk(2d-1)b_1 term in l. */ - pub fn nk_2d_min_1_b_1(&self) -> u64 { + pub fn nk_2d_min_1_b_1(&self) -> u32 { self.n() * self.k() * (2 * self.d() - 1) * self.b_1() } @@ -245,7 +245,7 @@ where for r in 0..self.bounds.rows { column_bound_sum += self.bounds[(r, c)] .iter() - .map(|b| if *b > 0 { 2u64.pow(*b as u32) } else { 0 }) + .map(|b| if *b > 0 { 2u64.pow(*b) } else { 0 }) .sum::(); } column_bound_sum @@ -256,7 +256,7 @@ where /** * The number of bits needed to store the elements of R1. */ - pub fn b_1(&self) -> u64 { + pub fn b_1(&self) -> u32 { let d_big = ZqRistretto::from(self.d()); let max_bounds_column_sum = ZqRistretto::from(self.max_bounds_column_sum()); @@ -270,14 +270,14 @@ where /** * The number of bits needed to store values in `Fp`. */ - pub fn b_2(&self) -> u64 { + pub fn b_2(&self) -> u32 { Log2::ceil_log2(&Q::field_modulus()) } /** * The length in bits of the binary expansion of the serialized secret * vectors. */ - pub fn l(&self) -> u64 { + pub fn l(&self) -> u32 { let total_bounds_all_equations = self.b_sum(); let nk = self.n().checked_mul(self.k()).unwrap(); @@ -496,19 +496,19 @@ impl LogProof { let r_1_serialized = Self::serialize(&r_1, (2 * d - 1) as usize); let r_2_serialized = Self::serialize(&r_2, (d - 1) as usize); - assert_eq!(s_serialized.len() as u64, m * k * d); + assert_eq!(s_serialized.len() as u32, m * k * d); - assert_eq!(r_1_serialized.len() as u64, n * k * (2 * d - 1)); - assert_eq!(r_2_serialized.len() as u64, n * k * (d - 1)); + assert_eq!(r_1_serialized.len() as u32, n * k * (2 * d - 1)); + assert_eq!(r_2_serialized.len() as u32, n * k * (d - 1)); let s_binary: BitVec = Self::to_2s_complement_multibound(&s_serialized, &b_serialized); - assert_eq!(s_binary.len() as u64, total_bounds_all_equations); + assert_eq!(s_binary.len() as u32, total_bounds_all_equations); let r_1_binary = Self::to_2s_complement(&r_1_serialized, b_1); - assert_eq!(r_1_binary.len() as u64, n * k * (2 * d - 1) * b_1); + assert_eq!(r_1_binary.len() as u32, n * k * (2 * d - 1) * b_1); let r_2_binary = Self::to_2s_complement(&r_2_serialized, b_2); - assert_eq!(r_2_binary.len() as u64, n * k * (d - 1) * b_2); + assert_eq!(r_2_binary.len() as u32, n * k * (d - 1) * b_2); let mut s_1 = s_binary.clone(); s_1.extend(r_1_binary.iter()); @@ -903,7 +903,7 @@ impl LogProof { two_b.as_slice(), ); - assert_eq!(term_1.len() as u64, b_sum); + assert_eq!(term_1.len() as u32, b_sum); // Compute term 2 let q = ZqRistretto::try_from(Q::field_modulus()).unwrap(); @@ -919,7 +919,7 @@ impl LogProof { .tensor(alpha_2d_minus_1) .tensor(two_b_1); - assert_eq!(term_2.len() as u64, b_1 * (2 * d - 1) * n * k); + assert_eq!(term_2.len() as u32, b_1 * (2 * d - 1) * n * k); // Compute term 3 let d_min_1 = d as usize - 1; @@ -935,7 +935,7 @@ impl LogProof { .tensor(alpha_d_minus_1) .tensor(two_b_2); - assert_eq!(term_3.len() as u64, b_2 * (d - 1) * n * k); + assert_eq!(term_3.len() as u32, b_2 * (d - 1) * n * k); let mut result = vec![]; @@ -1121,7 +1121,7 @@ impl LogProof { * panic on any other input. * */ - fn to_2s_complement_single(value: &Zq, log_b: u64, bitvec: &mut BitVec) + fn to_2s_complement_single(value: &Zq, log_b: u32, bitvec: &mut BitVec) where B: ArithmeticBackend, { @@ -1168,7 +1168,7 @@ impl LogProof { * `value` is the element in Zq and `b` is the number of bits needed * to represent the signed value. */ - fn to_2s_complement(values: &[Zq], log_b: u64) -> BitVec + fn to_2s_complement(values: &[Zq], log_b: u32) -> BitVec where B: ArithmeticBackend, { @@ -1192,14 +1192,14 @@ impl LogProof { * `value` is the element in Zq and `b` is the number of bits needed * to represent the signed value. */ - fn to_2s_complement_multibound(values: &[Zq], log_b: &[u64]) -> BitVec + fn to_2s_complement_multibound(values: &[Zq], log_b: &[u32]) -> BitVec where B: ArithmeticBackend, { // Make sure we have an equal number of values and bounds to serialize assert_eq!(values.len(), log_b.len()); - let mut bitvec = BitVec::with_capacity(log_b.iter().sum::() as usize); + let mut bitvec = BitVec::with_capacity(log_b.iter().sum::() as usize); // This code should not feature timing side-channels. for (value, bound) in zip(values.iter(), log_b.iter()) { @@ -1216,11 +1216,11 @@ impl LogProof { * The matrix is serialized in row-major order, with bound * coefficients being contiguous. */ - pub fn serialize_bounds(bounds: &Matrix) -> Vec { + pub fn serialize_bounds(bounds: &Matrix) -> Vec { bounds .as_slice() .iter() - .flat_map(|x| x.0.iter().map(|x| *x as u64)) + .flat_map(|x| x.0.iter().copied()) .collect() } @@ -1403,7 +1403,7 @@ mod test { 0 } else { next_higher_power_of_two(x.unsigned_abs()).ilog2() - as usize + as u32 } }) .collect::>(), diff --git a/logproof/src/math.rs b/logproof/src/math.rs index c8dcec561..0bf1b780d 100644 --- a/logproof/src/math.rs +++ b/logproof/src/math.rs @@ -1,6 +1,6 @@ use std::{borrow::Borrow, ops::Mul}; -use crypto_bigint::Uint; +use crypto_bigint::{Limb, Uint}; use curve25519_dalek::{ristretto::RistrettoPoint, scalar::Scalar, traits::VartimeMultiscalarMul}; use rand::Rng; use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; @@ -139,7 +139,7 @@ pub trait Log2 { * # Panics * When the given value is zero. */ - fn log2(&self) -> u64; + fn log2(&self) -> u32; /** * Compute the ceiling of the log2 of the given value. @@ -147,14 +147,14 @@ pub trait Log2 { * # Panics * When the given value is zero. */ - fn ceil_log2(&self) -> u64; + fn ceil_log2(&self) -> u32; } impl Log2 for u64 { /** * An implementation of log2 that works on stable. */ - fn log2(&self) -> u64 { + fn log2(&self) -> u32 { let mut mask = 0x8000_0000_0000_0000; for i in 0..64 { @@ -168,7 +168,7 @@ impl Log2 for u64 { panic!("Value was zero."); } - fn ceil_log2(&self) -> u64 { + fn ceil_log2(&self) -> u32 { let ceil_factor = if self.is_power_of_two() { 0 } else { 1 }; self.log2() + ceil_factor } @@ -179,7 +179,9 @@ fn is_power_of_two_bigint(b: &Uint) -> bool { } impl Log2 for Uint { - fn log2(&self) -> u64 { + fn log2(&self) -> u32 { + let limb_bits = Limb::BITS as u32; + for i in 0..self.as_limbs().len() { let i = self.as_limbs().len() - i - 1; let limb = self.as_limbs()[i]; @@ -188,13 +190,13 @@ impl Log2 for Uint { continue; } - return Log2::log2(&limb.0) + (i as u64) * 64; + return Log2::log2(&limb.0) + (i as u32) * limb_bits; } panic!("Value was zero."); } - fn ceil_log2(&self) -> u64 { + fn ceil_log2(&self) -> u32 { let ceil_factor = if is_power_of_two_bigint(self) { 0 } else { 1 }; self.log2() + ceil_factor @@ -202,11 +204,11 @@ impl Log2 for Uint { } impl> Log2 for Zq { - fn log2(&self) -> u64 { + fn log2(&self) -> u32 { Uint::::log2(&self.val) } - fn ceil_log2(&self) -> u64 { + fn ceil_log2(&self) -> u32 { Uint::::ceil_log2(&self.val) } } @@ -663,7 +665,7 @@ mod test { for value in options { let f = value as f64; - let expected = f.log2().ceil() as u64; + let expected = f.log2().ceil() as u32; let calculated = Log2::ceil_log2(&value); @@ -677,7 +679,7 @@ mod test { for value in options { let f = value as f64; - let expected = f.log2().ceil() as u64; + let expected = f.log2().ceil() as u32; let b: Uint<1> = Uint::from(value); let calculated = Log2::ceil_log2(&b); diff --git a/sunscreen_runtime/src/builder.rs b/sunscreen_runtime/src/builder.rs index 77b16ce04..755a50db0 100644 --- a/sunscreen_runtime/src/builder.rs +++ b/sunscreen_runtime/src/builder.rs @@ -742,7 +742,7 @@ mod linked { fn mk_bounds(&self) -> Bounds { let params = self.runtime.params(); - let mut bounds = vec![params.plain_modulus.ceil_log2() as usize; P::DEGREE_BOUND]; + let mut bounds = vec![params.plain_modulus.ceil_log2(); P::DEGREE_BOUND]; bounds.resize(params.lattice_dimension as usize, 0); Bounds(bounds) } diff --git a/sunscreen_runtime/src/linked.rs b/sunscreen_runtime/src/linked.rs index a9dff3bd7..3da8315e6 100644 --- a/sunscreen_runtime/src/linked.rs +++ b/sunscreen_runtime/src/linked.rs @@ -453,7 +453,7 @@ impl SealSdlpVerifierKnowledge { /// Get the length in bits of the binary expansion of the serialized secret * vectors. /// /// Delegate to [`LogProofVerifierKnowledge::l`]. - pub fn l(&self) -> u64 { + pub fn l(&self) -> u32 { seq_zq!({ match &self.0 { #( diff --git a/sunscreen_tfhe/src/zkp.rs b/sunscreen_tfhe/src/zkp.rs index 3fdd49b53..97073af83 100644 --- a/sunscreen_tfhe/src/zkp.rs +++ b/sunscreen_tfhe/src/zkp.rs @@ -221,7 +221,7 @@ fn compute_bounds( // Bounds for messages for i in 0..num_messages { let mut b = Bounds(vec![0; num_coeffs]); - b.0[0] = 0x1u64 << plaintext_bits.0; + b.0[0] = 0x1_u32 << plaintext_bits.0; debug_assert_eq!(bounds[(i, 0)].0, &[]); bounds[(i, 0)] = b; } @@ -232,7 +232,7 @@ fn compute_bounds( let mut b = Bounds(vec![0; num_coeffs]); // Values of r are binary - b.0[0] = 0x1u64 << plaintext_bits.0; + b.0[0] = 0x1_u32 << plaintext_bits.0; debug_assert_eq!( bounds[(offsets.public_keys + i * lwe_dimension + j, 0)].0, &[] @@ -242,7 +242,7 @@ fn compute_bounds( // e is normal distributed over the torus. // TODO: This bound is too high. Get a tighter bound. - let b = Bounds(vec![0x1u64 << (60 - plaintext_bits.0); num_coeffs]); + let b = Bounds(vec![0x1_u32 << (60 - plaintext_bits.0); num_coeffs]); debug_assert_eq!(bounds[(offsets.public_e + i, 0)].0, &[]); bounds[(offsets.public_e + i, 0)] = b; @@ -254,7 +254,7 @@ fn compute_bounds( let mut b = Bounds(vec![0; num_coeffs]); // Values of s are binary - b.0[0] = 0x1u64 << plaintext_bits.0; + b.0[0] = 0x1_u32 << plaintext_bits.0; debug_assert_eq!( bounds[(offsets.private_a + j + i * lwe_dimension, 0)].0, &[] @@ -266,7 +266,7 @@ fn compute_bounds( // e is normal distributed over the torus. // TODO: This bound is too high. Get a tighter bound. - b.0[0] = 0x1u64 << (62 - plaintext_bits.0); + b.0[0] = 0x1_u32 << (62 - plaintext_bits.0); debug_assert_eq!(bounds[(offsets.private_e + i, 0)].0, &[]); bounds[(offsets.private_e + i, 0)] = b; }