Skip to content

Commit

Permalink
Rweber/multi pbs (#365)
Browse files Browse the repository at this point in the history
  • Loading branch information
rickwebiii authored Mar 1, 2024
1 parent d2dfd45 commit a256e42
Show file tree
Hide file tree
Showing 6 changed files with 380 additions and 87 deletions.
6 changes: 3 additions & 3 deletions sunscreen_tfhe/src/entities/bootstrap_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
dst! {
/// Keys used for bootstrapping. The [BootstrapKeyFft] variant of this type
/// is used by the bootstrapping functions such as
/// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap).
/// [`programmable_bootstrap_univariate`](crate::ops::bootstrapping::programmable_bootstrap_univariate).
BootstrapKey,
BootstrapKeyRef,
Torus,
Expand Down Expand Up @@ -116,7 +116,7 @@ impl<S: TorusOps> BootstrapKeyRef<S> {

dst! {
/// Keys used for bootstrapping. Used by the bootstrapping functions such as
/// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap).
/// [`programmable_bootstrap_univariate`](crate::ops::bootstrapping::programmable_bootstrap_univariate).
/// The non-FFT variant of this type is [BootstrapKey].
BootstrapKeyFft,
BootstrapKeyFftRef,
Expand All @@ -140,7 +140,7 @@ impl BootstrapKeyFft<Complex<f64>> {
/// encrypts a single bit of an LWE secret key. In this representation, the
/// GGSW ciphertexts are in the frequency domain and can be used directly by
/// the bootstrapping functions such as
/// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap).
/// [`programmable_bootstrap_univariate`](crate::ops::bootstrapping::programmable_bootstrap_univariate).
pub fn new(lwe_params: &LweDef, glwe_params: &GlweDef, radix: &RadixDecomposition) -> Self {
let len = BootstrapKeyFftRef::size((lwe_params.dim, glwe_params.dim, radix.count));

Expand Down
41 changes: 35 additions & 6 deletions sunscreen_tfhe/src/entities/univariate_lookup_table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use super::GlweCiphertextRef;

dst! {
/// Lookup table for a univariate function used during
/// [`programmable_bootstrap`](crate::ops::bootstrapping::programmable_bootstrap)
/// [`programmable_bootstrap_univariate`](crate::ops::bootstrapping::programmable_bootstrap_univariate)
/// and [`circuit_bootstrap`](crate::ops::bootstrapping::circuit_bootstrap).
UnivariateLookupTable,
UnivariateLookupTableRef,
Expand All @@ -31,7 +31,11 @@ impl<S: TorusOps> OverlaySize for UnivariateLookupTableRef<S> {
}

impl<S: TorusOps> UnivariateLookupTable<S> {
/// Creates a lookup table that is trivially encrypted.
/// Creates a trivially encrypted lookup table that computes a single function `map`.
///
/// # Remarks
/// The result of this can be used with
/// [`programmable_bootstrap_univariate`](crate::ops::bootstrapping::programmable_bootstrap_univariate).
pub fn trivial_from_fn<F>(map: F, glwe: &GlweDef, plaintext_bits: PlaintextBits) -> Self
where
F: Fn(u64) -> u64,
Expand All @@ -40,7 +44,32 @@ impl<S: TorusOps> UnivariateLookupTable<S> {
data: vec![Torus::zero(); UnivariateLookupTableRef::<S>::size(glwe.dim)],
};

lut.fill_trivial_from_fn(map, glwe, plaintext_bits);
lut.fill_trivial_from_fns(&[map], glwe, plaintext_bits);

lut
}

/// Creates a trivially encrypted lookup table that computes multiple functions
/// given by `maps`.
///
/// # Remarks
/// The result of this should be used with
/// [`generalized_programmable_bootstrap`](crate::ops::bootstrapping::generalized_programmable_bootstrap).
pub fn trivivial_multifunctional<F>(
maps: &[F],
glwe: &GlweDef,
plaintext_bits: PlaintextBits,
) -> Self
where
F: Fn(u64) -> u64,
{
assert!(maps.len() > 1);

let mut lut = UnivariateLookupTable {
data: vec![Torus::zero(); UnivariateLookupTableRef::<S>::size(glwe.dim)],
};

lut.fill_trivial_from_fns(maps, glwe, plaintext_bits);

lut
}
Expand All @@ -60,15 +89,15 @@ impl<S: TorusOps> UnivariateLookupTableRef<S> {

/// Generates a look up table filled with the values from the provided map,
/// and trivially encrypts the lookup table.
pub fn fill_trivial_from_fn<F: Fn(u64) -> u64>(
pub fn fill_trivial_from_fns<F: Fn(u64) -> u64>(
&mut self,
map: F,
maps: &[F],
glwe: &GlweDef,
plaintext_bits: PlaintextBits,
) {
allocate_scratch_ref!(poly, PolynomialRef<Torus<S>>, (glwe.dim.polynomial_degree));

generate_lut(poly, map, glwe, plaintext_bits);
generate_lut(poly, maps, glwe, plaintext_bits);

trivially_encrypt_glwe_ciphertext(self.glwe_mut(), poly, glwe);
}
Expand Down
6 changes: 3 additions & 3 deletions sunscreen_tfhe/src/high_level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ pub mod keygen {
/// However, anyone who possesses `glwe_key` can easily use the returned
/// [`BootstrapKey`] to recover `sk`.
pub fn generate_bootstrapping_key(
sk: &LweSecretKey<u64>,
glwe_key: &GlweSecretKey<u64>,
sk: &LweSecretKeyRef<u64>,
glwe_key: &GlweSecretKeyRef<u64>,
lwe: &LweDef,
glwe: &GlweDef,
radix: &RadixDecomposition,
Expand Down Expand Up @@ -844,7 +844,7 @@ pub mod evaluation {
) -> LweCiphertext<u64> {
let mut out = LweCiphertext::new(&glwe.as_lwe_def());

crate::ops::bootstrapping::programmable_bootstrap(
crate::ops::bootstrapping::programmable_bootstrap_univariate(
&mut out, input, lut, bsk, lwe, glwe, radix,
);

Expand Down
111 changes: 83 additions & 28 deletions sunscreen_tfhe/src/ops/bootstrapping/circuit_bootstrapping.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use num::Complex;
use sunscreen_math::Zero;

use crate::{
dst::FromMutSlice,
entities::{
BootstrapKeyFftRef, CircuitBootstrappingKeyswitchKeysRef, GgswCiphertextRef,
LweCiphertextListRef, LweCiphertextRef, UnivariateLookupTableRef,
GlweCiphertextRef, LweCiphertextListRef, LweCiphertextRef, UnivariateLookupTableRef,
},
ops::{
bootstrapping::programmable_bootstrap, homomorphisms::rotate,
bootstrapping::generalized_programmable_bootstrap, ciphertext::sample_extract,
homomorphisms::rotate,
keyswitch::private_functional_keyswitch::private_functional_keyswitch,
},
scratch::allocate_scratch_ref,
Expand Down Expand Up @@ -206,13 +208,11 @@ fn level_0_to_level_2<S: TorusOps>(
pbs_radix: &RadixDecomposition,
cbs_radix: &RadixDecomposition,
) {
allocate_scratch_ref!(glwe_out, GlweCiphertextRef<S>, (glwe_2.dim));
allocate_scratch_ref!(lut, UnivariateLookupTableRef<S>, (glwe_2.dim));
allocate_scratch_ref!(lwe_rotated, LweCiphertextRef<S>, (lwe_0.dim));
allocate_scratch_ref!(
lwe_bootstrapped,
LweCiphertextRef<S>,
(glwe_2.as_lwe_def().dim)
);
allocate_scratch_ref!(extracted, LweCiphertextRef<S>, (glwe_2.as_lwe_def().dim));
assert!(cbs_radix.count.0 < 8);

// Rotate our input by q/4, putting 0 centered on q/4 and 1 centered on
// -q/4.
Expand All @@ -223,42 +223,96 @@ fn level_0_to_level_2<S: TorusOps>(
lwe_0,
);

let log_v = if cbs_radix.count.0.is_power_of_two() {
cbs_radix.count.0.ilog2()
} else {
cbs_radix.count.0.ilog2() + 1
};

fill_multifunctional_cbs_decomposition_lut(lut, glwe_2, cbs_radix);

generalized_programmable_bootstrap(
glwe_out,
lwe_rotated,
lut,
bsk,
0,
log_v,
lwe_0,
glwe_2,
pbs_radix,
);

for (i, lwe_2) in lwes_2.ciphertexts_mut(&glwe_2.as_lwe_def()).enumerate() {
let cur_level = i + 1;

// Treat value as a T_{b^l+1} with one extra place for rounding as the last
// step.
let plaintext_bits = PlaintextBits((cbs_radix.radix_log.0 * cur_level + 1) as u32);

// Exploiting the fact that our LUT is negacyclic, we can encode -1 in T_{b^l+1}
// everywhere. Any lookup < q/2 will give -1 and any lookup > q/2 will
// give 1. Since we've shifted our input lwe by q/4, a 1 plaintext
// value will map to 1 and a 0 will map to -1.
let minus_one = (S::one() << plaintext_bits.0 as usize) - S::one();

lut.fill_with_constant(minus_one, glwe_2, plaintext_bits);

programmable_bootstrap(
lwe_bootstrapped,
lwe_rotated,
lut,
bsk,
lwe_0,
glwe_2,
pbs_radix,
);
sample_extract(extracted, glwe_out, i, glwe_2);

// Now we rotate our message containing -1 or 1 by 1 (wrt plaintext_bits).
// This will overflow -1 to 0 and cause 1 to wrap to 2.
rotate(
lwe_2,
lwe_bootstrapped,
extracted,
Torus::encode(S::one(), plaintext_bits),
&glwe_2.as_lwe_def(),
);
}
}

fn fill_multifunctional_cbs_decomposition_lut<S: TorusOps>(
lut: &mut UnivariateLookupTableRef<S>,
glwe: &GlweDef,
cbs_radix: &RadixDecomposition,
) {
lut.clear();

// Pick a largish number of levels nobody would ever exceed.
let mut levels = [Torus::zero(); 16];

assert!(cbs_radix.count.0 < levels.len());

// Compute our base decomposition factors.
// Exploiting the fact that our LUT is negacyclic, we can encode -1 in T_{b^l+1}
// everywhere. Any lookup < q/2 will give -1 and any lookup > q/2 will
// give 1. Since we've shifted our input lwe by q/4, a 1 plaintext
// value will map to 1 and a 0 will map to -1.
for (i, x) in levels.iter_mut().enumerate() {
let i = i + 1;
if i * cbs_radix.radix_log.0 + 1 < S::BITS as usize {
let plaintext_bits = PlaintextBits((cbs_radix.radix_log.0 * i + 1) as u32);

let minus_one = (S::one() << plaintext_bits.0 as usize) - S::one();
*x = Torus::encode(minus_one, plaintext_bits);
}
}

// Fill the table with alternating factors padded with zeros to a power of 2
let log_v = if cbs_radix.count.0.is_power_of_two() {
cbs_radix.count.0.ilog2()
} else {
cbs_radix.count.0.ilog2() + 1
};

let v = 0x1usize << log_v;

for (i, x) in lut
.glwe_mut()
.b_mut(glwe)
.coeffs_mut()
.iter_mut()
.enumerate()
{
let fn_id = i % v;

*x = if fn_id < cbs_radix.count.0 {
levels[fn_id]
} else {
Torus::zero()
};
}
}

/// Bootstraps a level 2 GLWE ciphertext to a level 1 GLWE ciphertext.
pub fn level_2_to_level1<S: TorusOps>(
result: &mut GgswCiphertextRef<S>,
Expand Down Expand Up @@ -322,6 +376,7 @@ mod tests {

let sk = keygen::generate_binary_lwe_sk(&TEST_LWE_DEF_1);
let glwe_sk = keygen::generate_binary_glwe_sk(&glwe_params);

let bsk = keygen::generate_bootstrapping_key(
&sk,
&glwe_sk,
Expand Down
Loading

0 comments on commit a256e42

Please sign in to comment.