diff --git a/examples/amm/src/main.rs b/examples/amm/src/main.rs index a734802dc..a415a52a1 100644 --- a/examples/amm/src/main.rs +++ b/examples/amm/src/main.rs @@ -1,7 +1,7 @@ use sunscreen::{ fhe_program, types::{bfv::Rational, Cipher}, - Ciphertext, CompiledFheProgram, Compiler, Error, FheRuntime, Params, PrivateKey, PublicKey, + CompiledFheProgram, Compiler, Error, FheRuntime, Params, PrivateKey, PublicKey, }; #[fhe_program(scheme = "bfv")] @@ -37,14 +37,10 @@ impl Miner { pub fn run_contract( &self, - nu_tokens_to_trade: Ciphertext, + nu_tokens_to_trade: Cipher, public_key: &PublicKey, - ) -> Result { - let results = - self.runtime - .run(&self.compiled_swap_nu, vec![nu_tokens_to_trade], public_key)?; - - Ok(results[0].clone()) + ) -> Result, Error> { + swap_nu.run(&self.runtime, public_key, nu_tokens_to_trade) } } @@ -73,13 +69,13 @@ impl Alice { }) } - pub fn create_transaction(&self, amount: f64) -> Result { + pub fn create_transaction(&self, amount: f64) -> Result, Error> { Ok(self .runtime .encrypt(Rational::try_from(amount)?, &self.public_key)?) } - pub fn check_received_eth(&self, received_eth: Ciphertext) -> Result<(), Error> { + pub fn check_received_eth(&self, received_eth: Cipher) -> Result<(), Error> { let received_eth: Rational = self.runtime.decrypt(&received_eth, &self.private_key)?; let received_eth: f64 = received_eth.into(); diff --git a/examples/calculator_fractional/src/main.rs b/examples/calculator_fractional/src/main.rs index 4b12ac7d7..34b4dd1f0 100644 --- a/examples/calculator_fractional/src/main.rs +++ b/examples/calculator_fractional/src/main.rs @@ -6,8 +6,7 @@ use std::thread::{self, JoinHandle}; use sunscreen::{ fhe_program, types::{bfv::Fractional, Cipher}, - Ciphertext, Compiler, FheApplication, FheRuntime, Params, PlainModulusConstraint, PublicKey, - RuntimeError, + Compiler, FheApplication, FheRuntime, Params, PlainModulusConstraint, PublicKey, RuntimeError, }; fn help() { @@ -26,7 +25,7 @@ fn help() { enum Term { Ans, F64(f64), - Encrypted(Ciphertext), + Encrypted(Cipher>), } #[derive(PartialEq)] @@ -134,7 +133,7 @@ fn alice( send_pub: Sender, send_calc: Sender, recv_params: Receiver, - recv_res: Receiver, + recv_res: Receiver>>, ) -> JoinHandle<()> { thread::spawn(move || { let stdin = io::stdin(); @@ -191,7 +190,7 @@ fn alice( .unwrap(); // Get our result from Bob and print it. - let result: Ciphertext = recv_res.recv().unwrap(); + let result: Cipher> = recv_res.recv().unwrap(); let result: Fractional<64> = match runtime.decrypt(&result, &private_key) { Ok(v) => v, Err(RuntimeError::TooMuchNoise) => { @@ -207,22 +206,22 @@ fn alice( }) } -fn compile_fhe_programs() -> FheApplication { - #[fhe_program(scheme = "bfv")] - fn add(a: Cipher>, b: Cipher>) -> Cipher> { - a + b - } +#[fhe_program(scheme = "bfv")] +fn add(a: Cipher>, b: Cipher>) -> Cipher> { + a + b +} - #[fhe_program(scheme = "bfv")] - fn sub(a: Cipher>, b: Cipher>) -> Cipher> { - a - b - } +#[fhe_program(scheme = "bfv")] +fn sub(a: Cipher>, b: Cipher>) -> Cipher> { + a - b +} - #[fhe_program(scheme = "bfv")] - fn mul(a: Cipher>, b: Cipher>) -> Cipher> { - a * b - } +#[fhe_program(scheme = "bfv")] +fn mul(a: Cipher>, b: Cipher>) -> Cipher> { + a * b +} +fn compile_fhe_programs() -> FheApplication { Compiler::new() .fhe_program(add) .fhe_program(sub) @@ -238,7 +237,7 @@ fn bob( recv_pub: Receiver, recv_calc: Receiver, send_params: Sender, - send_res: Sender, + send_res: Sender>>, ) -> JoinHandle<()> { thread::spawn(move || { let app = compile_fhe_programs(); @@ -268,41 +267,14 @@ fn bob( _ => panic!("Alice sent us a plaintext!"), }; - let mut c = match op { - Operand::Add => runtime - .run( - app.get_fhe_program("add").unwrap(), - vec![left, right], - &public_key, - ) - .unwrap(), - Operand::Sub => runtime - .run( - app.get_fhe_program("sub").unwrap(), - vec![left, right], - &public_key, - ) - .unwrap(), - Operand::Mul => runtime - .run( - app.get_fhe_program("mul").unwrap(), - vec![left, right], - &public_key, - ) - .unwrap(), - // To do division, Alice must send us 1 / b and we - // multiply. - Operand::Div => runtime - .run( - app.get_fhe_program("mul").unwrap(), - vec![left, right], - &public_key, - ) - .unwrap(), + let c = match op { + Operand::Add => add.run(&runtime, &public_key, left, right).unwrap(), + Operand::Sub => sub.run(&runtime, &public_key, left, right).unwrap(), + Operand::Mul => mul.run(&runtime, &public_key, left, right).unwrap(), + Operand::Div => mul.run(&runtime, &public_key, left, right).unwrap(), }; // Our FHE program produces a single value, so move the value out of the vector. - let c = c.drain(0..).next().unwrap(); ans = c.clone(); send_res.send(c).unwrap(); @@ -321,7 +293,8 @@ fn main() { let (send_bob_params, receive_bob_params) = std::sync::mpsc::channel::(); // A channel for Bob to send calculation results to Alice. - let (send_bob_result, receive_bob_result) = std::sync::mpsc::channel::(); + let (send_bob_result, receive_bob_result) = + std::sync::mpsc::channel::>>(); // We intentionally break Alice and Bob's roles into different functions to clearly // show the separation of their roles. In a real application, they're usually on diff --git a/examples/calculator_rational/src/main.rs b/examples/calculator_rational/src/main.rs index ded3805c0..84c37c81b 100644 --- a/examples/calculator_rational/src/main.rs +++ b/examples/calculator_rational/src/main.rs @@ -7,7 +7,7 @@ use sunscreen::FheRuntime; use sunscreen::{ fhe_program, types::{bfv::Rational, Cipher}, - Ciphertext, Compiler, FheApplication, Params, PlainModulusConstraint, PublicKey, RuntimeError, + Compiler, FheApplication, Params, PlainModulusConstraint, PublicKey, RuntimeError, }; fn help() { @@ -26,7 +26,7 @@ fn help() { enum Term { Ans, F64(f64), - Encrypted(Ciphertext), + Encrypted(Cipher), } enum Operand { @@ -116,7 +116,7 @@ fn alice( send_pub: Sender, send_calc: Sender, recv_params: Receiver, - recv_res: Receiver, + recv_res: Receiver>, ) -> JoinHandle<()> { thread::spawn(move || { let stdin = io::stdin(); @@ -173,7 +173,7 @@ fn alice( .unwrap(); // Get our result from Bob and print it. - let result: Ciphertext = recv_res.recv().unwrap(); + let result: Cipher = recv_res.recv().unwrap(); let result: Rational = match runtime.decrypt(&result, &private_key) { Ok(v) => v, Err(RuntimeError::TooMuchNoise) => { @@ -189,27 +189,27 @@ fn alice( }) } -fn compile_fhe_programs() -> FheApplication { - #[fhe_program(scheme = "bfv")] - fn add(a: Cipher, b: Cipher) -> Cipher { - a + b - } +#[fhe_program(scheme = "bfv")] +fn add(a: Cipher, b: Cipher) -> Cipher { + a + b +} - #[fhe_program(scheme = "bfv")] - fn sub(a: Cipher, b: Cipher) -> Cipher { - a - b - } +#[fhe_program(scheme = "bfv")] +fn sub(a: Cipher, b: Cipher) -> Cipher { + a - b +} - #[fhe_program(scheme = "bfv")] - fn mul(a: Cipher, b: Cipher) -> Cipher { - a * b - } +#[fhe_program(scheme = "bfv")] +fn mul(a: Cipher, b: Cipher) -> Cipher { + a * b +} - #[fhe_program(scheme = "bfv")] - fn div(a: Cipher, b: Cipher) -> Cipher { - a / b - } +#[fhe_program(scheme = "bfv")] +fn div(a: Cipher, b: Cipher) -> Cipher { + a / b +} +fn compile_fhe_programs() -> FheApplication { // We compile all the programs together so they have compatible // scheme parameters Compiler::new() @@ -228,7 +228,7 @@ fn bob( recv_pub: Receiver, recv_calc: Receiver, send_params: Sender, - send_res: Sender, + send_res: Sender>, ) -> JoinHandle<()> { thread::spawn(move || { let app = compile_fhe_programs(); @@ -258,39 +258,14 @@ fn bob( _ => panic!("Alice sent us a plaintext!"), }; - let mut c = match op { - Operand::Add => runtime - .run( - app.get_fhe_program("add").unwrap(), - vec![left, right], - &public_key, - ) - .unwrap(), - Operand::Sub => runtime - .run( - app.get_fhe_program("sub").unwrap(), - vec![left, right], - &public_key, - ) - .unwrap(), - Operand::Mul => runtime - .run( - app.get_fhe_program("mul").unwrap(), - vec![left, right], - &public_key, - ) - .unwrap(), - Operand::Div => runtime - .run( - app.get_fhe_program("div").unwrap(), - vec![left, right], - &public_key, - ) - .unwrap(), + let c = match op { + Operand::Add => add.run(&runtime, &public_key, left, right).unwrap(), + Operand::Sub => sub.run(&runtime, &public_key, left, right).unwrap(), + Operand::Mul => mul.run(&runtime, &public_key, left, right).unwrap(), + Operand::Div => div.run(&runtime, &public_key, left, right).unwrap(), }; // Our FHE program produces a single value, so move the value out of the vector. - let c = c.drain(0..).next().unwrap(); ans = c.clone(); send_res.send(c).unwrap(); @@ -309,7 +284,7 @@ fn main() { let (send_bob_params, receive_bob_params) = std::sync::mpsc::channel::(); // A channel for Bob to send calculation results to Alice. - let (send_bob_result, receive_bob_result) = std::sync::mpsc::channel::(); + let (send_bob_result, receive_bob_result) = std::sync::mpsc::channel::>(); // We intentionally break Alice and Bob's roles into different functions to clearly // show the separation of their roles. In a real application, they're usually on diff --git a/examples/chi_sq/src/main.rs b/examples/chi_sq/src/main.rs index b4c07af40..8834470df 100644 --- a/examples/chi_sq/src/main.rs +++ b/examples/chi_sq/src/main.rs @@ -169,7 +169,7 @@ fn run_fhe( ) -> Result<(), Error> where F: FheProgramFn + Clone + 'static + AsRef, - U: From + FheType + TypeName + std::fmt::Display, + U: From + FheType + TypeName + std::fmt::Display + 'static, { let start = Instant::now(); diff --git a/examples/pir/src/main.rs b/examples/pir/src/main.rs index 53800b0af..a8af0a49f 100644 --- a/examples/pir/src/main.rs +++ b/examples/pir/src/main.rs @@ -3,8 +3,7 @@ use sunscreen::{ fhe_program, types::{bfv::Signed, Cipher}, - Ciphertext, CompiledFheProgram, Compiler, Error, FheProgramInput, FheRuntime, Params, - PrivateKey, PublicKey, + CompiledFheProgram, Compiler, Error, FheRuntime, Params, PrivateKey, PublicKey, }; const SQRT_DATABASE_SIZE: usize = 10; @@ -67,10 +66,10 @@ impl Server { pub fn run_query( &self, - col_query: Ciphertext, - row_query: Ciphertext, + col_query: [Cipher; SQRT_DATABASE_SIZE], + row_query: [Cipher; SQRT_DATABASE_SIZE], public_key: &PublicKey, - ) -> Result { + ) -> Result, Error> { // Our database will consist of values between 400 and 500. let mut database = [[Signed::from(0); SQRT_DATABASE_SIZE]; SQRT_DATABASE_SIZE]; let mut val = Signed::from(400); @@ -82,11 +81,7 @@ impl Server { } } - let args: Vec = vec![col_query.into(), row_query.into(), database.into()]; - - let results = self.runtime.run(&self.compiled_lookup, args, public_key)?; - - Ok(results[0].clone()) + lookup.run(&self.runtime, public_key, col_query, row_query, database) } } @@ -116,7 +111,16 @@ impl Alice { }) } - pub fn create_query(&self, index: usize) -> Result<(Ciphertext, Ciphertext), Error> { + pub fn create_query( + &self, + index: usize, + ) -> Result< + ( + [Cipher; SQRT_DATABASE_SIZE], + [Cipher; SQRT_DATABASE_SIZE], + ), + Error, + > { let col = index % SQRT_DATABASE_SIZE; let row = index / SQRT_DATABASE_SIZE; @@ -126,12 +130,12 @@ impl Alice { row_query[row] = Signed::from(1); Ok(( - self.runtime.encrypt(col_query, &self.public_key)?, - self.runtime.encrypt(row_query, &self.public_key)?, + col_query.map(|x| self.runtime.encrypt(x, &self.public_key).unwrap()), + row_query.map(|x| self.runtime.encrypt(x, &self.public_key).unwrap()), )) } - pub fn check_response(&self, value: Ciphertext) -> Result<(), Error> { + pub fn check_response(&self, value: Cipher) -> Result<(), Error> { let value: Signed = self.runtime.decrypt(&value, &self.private_key)?; let value: i64 = value.into(); diff --git a/examples/private_tx_linkedproof/src/main.rs b/examples/private_tx_linkedproof/src/main.rs index c7f39db36..ffa18fe8b 100644 --- a/examples/private_tx_linkedproof/src/main.rs +++ b/examples/private_tx_linkedproof/src/main.rs @@ -17,8 +17,8 @@ use sunscreen::{ }, Cipher, }, - zkp_program, zkp_var, Ciphertext, CompiledFheProgram, CompiledZkpProgram, Compiler, - FheProgramInput, FheZkpApplication, FheZkpRuntime, Params, PrivateKey, PublicKey, Result, + zkp_program, zkp_var, CompiledFheProgram, CompiledZkpProgram, Compiler, FheZkpApplication, + FheZkpRuntime, Params, PrivateKey, PublicKey, Result, }; /// Subtract the transaction amount from the sender's balance. @@ -229,7 +229,7 @@ impl User { pub struct Register { proof: LinkedProof, public_key: PublicKey, - encrypted_amount: Ciphertext, + encrypted_amount: Cipher, deposit: Deposit, } @@ -249,9 +249,9 @@ pub struct Deposit { pub struct Transfer { proof: LinkedProof, // Transfer amount encrypted under sender's key - encrypted_amount_sender: Ciphertext, + encrypted_amount_sender: Cipher, // Transfer amount encrypted under receiver's key - encrypted_amount_receiver: Ciphertext, + encrypted_amount_receiver: Cipher, sender: Username, receiver: Username, } @@ -265,7 +265,7 @@ pub struct Transfer { #[derive(Clone)] pub struct RefreshBalance { proof: LinkedProof, - fresh_balance: Ciphertext, + fresh_balance: Cipher, name: Username, } @@ -285,7 +285,7 @@ pub enum Transaction { /// "send" functionalities, i.e. sending transactions that can mutate chain data. pub struct Chain { /// The current balances - balances: HashMap, + balances: HashMap>, /// The user's public keys keys: HashMap, /// Ledger of transactions @@ -345,14 +345,12 @@ impl Chain { // Deposit into the user's balance let pk = self.keys.get(&name).unwrap(); let curr_bal = self.balances.get_mut(&name).unwrap(); - *curr_bal = self - .runtime - .run::( - self.app.get_deposit_to_fhe(), - vec![curr_bal.clone().into(), Signed::from(public_amount).into()], - pk, - )? - .remove(0); + *curr_bal = deposit_to.run( + &self.runtime, + pk, + curr_bal.clone(), + Signed::from(public_amount), + )?; Ok(()) } @@ -386,26 +384,22 @@ impl Chain { // Update the sender's balance: let sender_pk = self.keys.get(&sender).unwrap(); let sender_balance = self.balances.get_mut(&sender).unwrap(); - *sender_balance = self - .runtime - .run( - self.app.get_transfer_from_fhe(), - vec![sender_balance.clone(), encrypted_amount_sender], - sender_pk, - )? - .remove(0); + *sender_balance = transfer_from.run( + &self.runtime, + sender_pk, + sender_balance.clone(), + encrypted_amount_sender, + )?; // Update receiver's balance let receiver_pk = self.keys.get(&receiver).unwrap(); let receiver_balance = self.balances.get_mut(&receiver).unwrap(); - *receiver_balance = self - .runtime - .run( - self.app.get_transfer_to_fhe(), - vec![receiver_balance.clone(), encrypted_amount_receiver], - receiver_pk, - )? - .remove(0); + *receiver_balance = transfer_to.run( + &self.runtime, + receiver_pk, + receiver_balance.clone(), + encrypted_amount_receiver, + )?; Ok(()) } diff --git a/examples/simple_multiply/src/main.rs b/examples/simple_multiply/src/main.rs index 73bf2f130..6cd37989f 100644 --- a/examples/simple_multiply/src/main.rs +++ b/examples/simple_multiply/src/main.rs @@ -1,7 +1,7 @@ use sunscreen::{ fhe_program, types::{bfv::Signed, Cipher}, - Compiler, Error, FheRuntime, + FheProgramFnExt, Result, }; /** @@ -22,31 +22,8 @@ fn simple_multiply(a: Cipher, b: Cipher) -> Cipher { a * b } -fn main() -> Result<(), Error> { - /* - * Here we compile the FHE program we previously declared. In the first step, - * we create our compiler, specify that we want to compile - * `simple_multiple`, and build it with the default settings. - * - * The `?` operator is Rust's standard - * error handling mechanism; it returns from the current function (`main`) - * when an error occurs (shouldn't happen). - * - * On success, compilation returns an [`Application`], which - * stores a group of FHE programs compiled under the same scheme parameters. - * These parameters are an implementation detail of FHE. - * While Sunscreen allows experts to explicitly set the scheme parameters, - * we're using the default behavior: automatically choose parameters - * yielding good performance while maintaining correctness. - */ - let app = Compiler::new().fhe_program(simple_multiply).compile()?; - - /* - * Next, we construct a runtime, which provides the APIs for encryption, - * decryption, and running an FHE program. We need to pass - * the scheme parameters our compiler chose. - */ - let runtime = FheRuntime::new(app.params())?; +fn main() -> Result<()> { + let runtime = simple_multiply.runtime()?; /* * Here, we generate a public and private key pair. Normally, Alice does this, @@ -57,22 +34,9 @@ fn main() -> Result<(), Error> { let a = runtime.encrypt(Signed::from(15), &public_key)?; let b = runtime.encrypt(Signed::from(5), &public_key)?; - /* - * Now, we run the FHE program with our arguments. This produces a results - * `Vec` containing the encrypted outputs of the FHE program. - */ - let results = runtime.run( - app.get_fhe_program(simple_multiply).unwrap(), - vec![a, b], - &public_key, - )?; - - /* - * Finally, we decrypt our program's output so we can check it. Our FHE - * program outputs a `Signed` single value as the result, so we just take - * the first element. - */ - let c: Signed = runtime.decrypt(&results[0], &private_key)?; + let spf = simple_multiply.as_spf(&public_key); + let result = spf(a, b)?; + let c: Signed = runtime.decrypt(&result, &private_key)?; /* * Yay, 5 * 15 indeed equals 75. @@ -87,7 +51,7 @@ mod tests { use super::*; #[test] - fn main_works() -> Result<(), Error> { + fn main_works() -> Result<()> { main() } } diff --git a/sunscreen/src/lib.rs b/sunscreen/src/lib.rs index 9816f1ab4..c946339dc 100644 --- a/sunscreen/src/lib.rs +++ b/sunscreen/src/lib.rs @@ -53,7 +53,7 @@ pub mod zkp; use fhe::{FheOperation, Literal}; use petgraph::stable_graph::StableGraph; use serde::{Deserialize, Serialize}; -use sunscreen_runtime::{marker, Fhe, FheZkp, Zkp}; +use sunscreen_runtime::{Fhe, FheZkp, Zkp}; use std::cell::RefCell; use std::collections::HashMap; @@ -66,10 +66,11 @@ pub use seal_fhe::Plaintext as SealPlaintext; pub use sunscreen_compiler_macros::*; pub use sunscreen_fhe_program::{SchemeType, SecurityLevel}; pub use sunscreen_runtime::{ - CallSignature, Ciphertext, CompiledFheProgram, CompiledZkpProgram, Error as RuntimeError, - FheProgramInput, FheProgramInputTrait, FheProgramMetadata, FheRuntime, FheZkpRuntime, - InnerCiphertext, InnerPlaintext, Params, Plaintext, PrivateKey, ProofBuilder, PublicKey, - RequiredKeys, Runtime, VerificationBuilder, WithContext, ZkpProgramInput, ZkpRuntime, + marker, CallSignature, Ciphertext, CompiledFheProgram, CompiledZkpProgram, + Error as RuntimeError, FheProgramInput, FheProgramMetadata, FheProgramPlaintextInput, + FheRuntime, FheZkpRuntime, GenericRuntime, InnerCiphertext, InnerPlaintext, Params, Plaintext, + PrivateKey, ProofBuilder, PublicKey, RequiredKeys, Runtime, VerificationBuilder, WithContext, + ZkpProgramInput, ZkpRuntime, }; #[cfg(feature = "bulletproofs")] pub use sunscreen_zkp_backend::bulletproofs; diff --git a/sunscreen/src/types/bfv/batched.rs b/sunscreen/src/types/bfv/batched.rs index a20053ce5..c111d3d1c 100644 --- a/sunscreen/src/types/bfv/batched.rs +++ b/sunscreen/src/types/bfv/batched.rs @@ -6,14 +6,14 @@ use crate::{ BfvType, FheType, LaneCount, NumCiphertexts, SwapRows, TryFromPlaintext, TryIntoPlaintext, Type, TypeName, TypeNameInstance, Version, }, - FheProgramInputTrait, InnerPlaintext, Params, Plaintext, WithContext, + FheProgramPlaintextInput, InnerPlaintext, Params, Plaintext, WithContext, }; use seal_fhe::{ BFVEncoder, BfvEncryptionParametersBuilder, Context as SealContext, Modulus, Result as SealResult, }; use std::ops::*; -use sunscreen_runtime::{Error as RuntimeError, Result as RuntimeResult}; +use sunscreen_runtime::{Error as RuntimeError, FheProgramInput, Result as RuntimeResult}; /** * A Batched vector of signed integers. The vector has 2 rows of `LANES` @@ -91,7 +91,12 @@ impl TypeNameInstance for Batched { } } -impl FheProgramInputTrait for Batched {} +impl FheProgramPlaintextInput for Batched {} +impl From> for FheProgramInput { + fn from(value: Batched) -> Self { + Self::Plaintext(Box::new(value)) + } +} impl FheType for Batched {} impl BfvType for Batched {} diff --git a/sunscreen/src/types/bfv/fractional.rs b/sunscreen/src/types/bfv/fractional.rs index 1cbe0b8b9..4f2513d2b 100644 --- a/sunscreen/src/types/bfv/fractional.rs +++ b/sunscreen/src/types/bfv/fractional.rs @@ -14,12 +14,12 @@ use crate::{ }; use crate::{ types::{intern::FheProgramNode, BfvType, FheType, Type, Version}, - FheProgramInputTrait, Params, WithContext, + FheProgramPlaintextInput, Params, WithContext, }; use sunscreen_runtime::{ - InnerPlaintext, NumCiphertexts, Plaintext, TryFromPlaintext, TryIntoPlaintext, TypeName, - TypeNameInstance, + FheProgramInput, InnerPlaintext, NumCiphertexts, Plaintext, TryFromPlaintext, TryIntoPlaintext, + TypeName, TypeNameInstance, }; use std::ops::*; @@ -174,7 +174,12 @@ impl NumCiphertexts for Fractional { const NUM_CIPHERTEXTS: usize = 1; } -impl FheProgramInputTrait for Fractional {} +impl FheProgramPlaintextInput for Fractional {} +impl From> for FheProgramInput { + fn from(value: Fractional) -> Self { + Self::Plaintext(Box::new(value)) + } +} impl Default for Fractional { fn default() -> Self { diff --git a/sunscreen/src/types/bfv/rational.rs b/sunscreen/src/types/bfv/rational.rs index 1aa0f09dc..5c1c7e514 100644 --- a/sunscreen/src/types/bfv/rational.rs +++ b/sunscreen/src/types/bfv/rational.rs @@ -4,10 +4,10 @@ use crate::types::{ bfv::Signed, intern::FheProgramNode, ops::*, BfvType, Cipher, FheType, NumCiphertexts, TryFromPlaintext, TryIntoPlaintext, TypeName, }; -use crate::{FheProgramInputTrait, InnerPlaintext, Params, Plaintext, TypeName}; +use crate::{FheProgramPlaintextInput, InnerPlaintext, Params, Plaintext, TypeName}; use std::cmp::Eq; use std::ops::*; -use sunscreen_runtime::Error; +use sunscreen_runtime::{impl_into_fhe_program_plaintext_input, Error}; use num::Rational64; @@ -96,7 +96,8 @@ impl TryIntoPlaintext for Rational { } } -impl FheProgramInputTrait for Rational {} +impl FheProgramPlaintextInput for Rational {} +impl_into_fhe_program_plaintext_input!(Rational); impl FheType for Rational {} impl BfvType for Rational {} diff --git a/sunscreen/src/types/bfv/signed.rs b/sunscreen/src/types/bfv/signed.rs index 967cbc4fe..c42fe6d53 100644 --- a/sunscreen/src/types/bfv/signed.rs +++ b/sunscreen/src/types/bfv/signed.rs @@ -15,11 +15,12 @@ use crate::{ }; use crate::{ types::{intern::FheProgramNode, BfvType, FheType, TypeNameInstance}, - FheProgramInputTrait, Params, TypeName as DeriveTypeName, WithContext, + FheProgramPlaintextInput, Params, TypeName as DeriveTypeName, WithContext, }; use sunscreen_runtime::{ - InnerPlaintext, NumCiphertexts, Plaintext, TryFromPlaintext, TryIntoPlaintext, + impl_into_fhe_program_plaintext_input, InnerPlaintext, NumCiphertexts, Plaintext, + TryFromPlaintext, TryIntoPlaintext, }; use std::ops::*; @@ -52,7 +53,8 @@ mod sharing { } } -impl FheProgramInputTrait for Signed {} +impl FheProgramPlaintextInput for Signed {} +impl_into_fhe_program_plaintext_input!(Signed); impl FheType for Signed {} impl BfvType for Signed {} diff --git a/sunscreen/src/types/bfv/unsigned.rs b/sunscreen/src/types/bfv/unsigned.rs index f2841bfcb..2daaa3464 100644 --- a/sunscreen/src/types/bfv/unsigned.rs +++ b/sunscreen/src/types/bfv/unsigned.rs @@ -5,7 +5,7 @@ use paste::paste; use seal_fhe::Plaintext as SealPlaintext; use sunscreen_runtime::{ - InnerPlaintext, NumCiphertexts, Plaintext, TryFromPlaintext, TryIntoPlaintext, + FheProgramInput, InnerPlaintext, NumCiphertexts, Plaintext, TryFromPlaintext, TryIntoPlaintext, }; use crate as sunscreen; @@ -23,7 +23,7 @@ use crate::{ }; use crate::{ types::{intern::FheProgramNode, BfvType, FheType, TypeNameInstance}, - FheProgramInputTrait, Params, TypeName as DeriveTypeName, WithContext, + FheProgramPlaintextInput, Params, TypeName as DeriveTypeName, WithContext, }; #[derive(Debug, Clone, Copy, DeriveTypeName, PartialEq, Eq)] @@ -38,7 +38,12 @@ impl NumCiphertexts for Unsigned { const NUM_CIPHERTEXTS: usize = 1; } -impl FheProgramInputTrait for Unsigned {} +impl FheProgramPlaintextInput for Unsigned {} +impl From> for FheProgramInput { + fn from(value: Unsigned) -> Self { + Self::Plaintext(Box::new(value)) + } +} impl FheType for Unsigned {} impl BfvType for Unsigned {} diff --git a/sunscreen/src/types/intern/fhe_program_node.rs b/sunscreen/src/types/intern/fhe_program_node.rs index 474012ef5..10e5ed055 100644 --- a/sunscreen/src/types/intern/fhe_program_node.rs +++ b/sunscreen/src/types/intern/fhe_program_node.rs @@ -12,7 +12,7 @@ use sunscreen_runtime::TypeNameInstance; use std::ops::{Add, Div, Mul, Neg, Shl, Shr, Sub}; -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq)] /** * A type that wraps an FheType during graph construction. It is an implementation * detail and you should not construct these directly. @@ -56,6 +56,17 @@ pub struct FheProgramNode { _phantom: std::marker::PhantomData, } +impl Copy for FheProgramNode {} +impl Clone for FheProgramNode { + fn clone(&self) -> Self { + Self { + ids: self.ids, + stage: self.stage.clone(), + _phantom: self._phantom, + } + } +} + impl FheProgramNode { /** * Creates a new FHE program node with the given node index. diff --git a/sunscreen/src/types/mod.rs b/sunscreen/src/types/mod.rs index 4eeb37fa1..491c898e7 100644 --- a/sunscreen/src/types/mod.rs +++ b/sunscreen/src/types/mod.rs @@ -75,7 +75,7 @@ mod ops; pub mod zkp; pub use sunscreen_runtime::{ - BfvType, FheType, NumCiphertexts, TryFromPlaintext, TryIntoPlaintext, Type, TypeName, + BfvType, Cipher, FheType, NumCiphertexts, TryFromPlaintext, TryIntoPlaintext, Type, TypeName, TypeNameInstance, Version, }; @@ -104,36 +104,6 @@ pub trait LaneCount { fn lane_count() -> usize; } -#[derive(Copy, Clone, Debug)] -/** - * Declares a type T as being encrypted in an [`fhe_program`](crate::fhe_program). - */ -pub struct Cipher -where - T: FheType, -{ - _val: T, -} - -impl NumCiphertexts for Cipher -where - T: FheType, -{ - const NUM_CIPHERTEXTS: usize = T::NUM_CIPHERTEXTS; -} - -impl TypeName for Cipher -where - T: FheType + TypeName, -{ - fn type_name() -> Type { - Type { - is_encrypted: true, - ..T::type_name() - } - } -} - /// Creates new FHE variables from literals. /// /// Note that literals can be used directly in arithmetic operations with ciphertexts: diff --git a/sunscreen/tests/array.rs b/sunscreen/tests/array.rs index 41bdccb90..ad53b4e51 100644 --- a/sunscreen/tests/array.rs +++ b/sunscreen/tests/array.rs @@ -91,12 +91,15 @@ fn multidimensional_arrays() { let result = runtime .run( app.get_fhe_program(determinant).unwrap(), - vec![a_c], + vec![a_c.clone()], &public_key, ) .unwrap(); - let c: Signed = runtime.decrypt(&result[0], &private_key).unwrap(); + let spf = determinant.as_spf(&public_key); + let res = spf(a_c).unwrap(); + + let c: Signed = runtime.decrypt(&res, &private_key).unwrap(); assert_eq!(c, Signed::from(-3)); assert_eq!(c, determinant_impl(matrix)); diff --git a/sunscreen/tests/sdlp.rs b/sunscreen/tests/sdlp.rs index a19f9be76..b7dd09dc1 100644 --- a/sunscreen/tests/sdlp.rs +++ b/sunscreen/tests/sdlp.rs @@ -94,11 +94,7 @@ mod sdlp_tests { fn double(x: sunscreen::types::Cipher) -> sunscreen::types::Cipher { x + x } - let double_compiled = double.compile().unwrap(); - let computed_ct = rt - .run(&double_compiled, vec![initial_ct], &public_key) - .unwrap() - .remove(0); + let computed_ct = double.run(&rt, &public_key, initial_ct).unwrap(); let mut logproof_builder = SdlpBuilder::new(&rt); let (_, msg) = logproof_builder diff --git a/sunscreen_compiler_macros/src/fhe_program.rs b/sunscreen_compiler_macros/src/fhe_program.rs index 957755cf1..ce2334195 100644 --- a/sunscreen_compiler_macros/src/fhe_program.rs +++ b/sunscreen_compiler_macros/src/fhe_program.rs @@ -118,6 +118,51 @@ impl<'a> FheProgram<'a> { .collect() } + // The arg types of the FnOnce returned from `as_spf`. + fn spf_arg_types(&self) -> Vec { + self.unwrapped_inputs + .iter() + .map(|(_, ty, _)| { + let ty = map_spf_type(ty).unwrap(); + quote! { #ty } + }) + .collect() + } + + // The return types of the FnOnce returned from `as_spf`. + fn spf_return_types(&self) -> Vec { + self.return_types + .iter() + .map(|ty| { + let ty = map_spf_type(ty).unwrap(); + quote! { #ty } + }) + .collect() + } + + // The closure args in the implementation of `as_spf`. + fn spf_args(&self) -> Vec { + self.unwrapped_inputs + .iter() + .map(|(_, _, name)| { + quote! { #name } + }) + .collect() + } + + // The closure args in the implementation of `as_spf`. + fn run_args(&self) -> Vec { + self.unwrapped_inputs + .iter() + .map(|(_, ty, name)| { + let ty = map_spf_type(ty).unwrap(); + quote! { + #name: #ty + } + }) + .collect() + } + // Variable declarations like (but not exactly): // `__c_0: FheProgramNode> = FheProgramNode::input()` fn fhe_arg_var_decl(&self) -> Vec { @@ -196,6 +241,13 @@ impl<'a> FheProgram<'a> { let fhe_program_name_literal = format!("{}", fhe_program_name); + let spf_arg_types = self.spf_arg_types(); + let spf_return_type = pack_into_tuple(&self.spf_return_types()); + let spf_args = self.spf_args(); + let spf_return_idents = inner_return_idents; + let spf_return_tuple = pack_into_tuple(&spf_return_idents); + let run_args = self.run_args(); + quote! { #[allow(non_camel_case_types)] #[derive(Clone)] @@ -289,6 +341,40 @@ impl<'a> FheProgram<'a> { } } + impl #fhe_program_struct_name { + pub fn as_spf<'a>( + &'a self, + public_key: &'a sunscreen::PublicKey, + ) -> impl FnOnce( + #(#spf_arg_types),* + ) -> sunscreen::Result<#spf_return_type> + 'a { + |#(#spf_args),*| { + let runtime = sunscreen::FheProgramFnExt::runtime(self)?; + self.run(&runtime, public_key, #(#spf_args),*) + } + } + + pub fn run<'a, T: sunscreen::marker::Fhe, B>( + &'a self, + runtime: &'a sunscreen::GenericRuntime, + public_key: &'a sunscreen::PublicKey, + #(#run_args),* + ) -> sunscreen::Result<#spf_return_type> { + // TODO cache this in a Lazy field? would need to export type from sunscreen lib + let prog = sunscreen::FheProgramFnExt::compile(self)?; + let inputs: ::std::vec::Vec = ::std::vec![ + #( + #spf_args.into(), + )* + ]; + let mut results = runtime.run(&prog, inputs, public_key)?; + #( + let #spf_return_idents = sunscreen::types::Cipher::cast(results.remove(0))?; + )*; + Ok(#spf_return_tuple) + } + } + #[allow(non_upper_case_globals)] #vis const #fhe_program_name: #fhe_program_struct_name = #fhe_program_struct_name { chain_count: #chain_count diff --git a/sunscreen_compiler_macros/src/fhe_program_transforms.rs b/sunscreen_compiler_macros/src/fhe_program_transforms.rs index 801eeb767..832d922e2 100644 --- a/sunscreen_compiler_macros/src/fhe_program_transforms.rs +++ b/sunscreen_compiler_macros/src/fhe_program_transforms.rs @@ -31,6 +31,15 @@ pub fn map_fhe_type(arg_type: &Type) -> Result { Ok(transformed_type) } +/** + * Given an input type T, returns + * * sunscreen::Ciphertext when T is a Cipher<_> or [[..[Cipher<_>; N];.. M]; Q] + * * T otherwise + */ +pub fn map_spf_type(arg_type: &Type) -> Result { + Ok(arg_type.clone()) +} + /** * Emits code to make an FHE program node for the given * type T. diff --git a/sunscreen_runtime/src/array.rs b/sunscreen_runtime/src/array.rs index 6bf1ab310..865e911c8 100644 --- a/sunscreen_runtime/src/array.rs +++ b/sunscreen_runtime/src/array.rs @@ -1,8 +1,9 @@ use crate::{ - Error, FheProgramInputTrait, InnerPlaintext, NumCiphertexts, Params, Plaintext, Result, - TryFromPlaintext, TryIntoPlaintext, Type, TypeName, TypeNameInstance, WithContext, + Ciphertext, Error, FheProgramCiphertextInput, FheProgramPlaintextInput, InnerCiphertext, + InnerPlaintext, IntoCiphertext, NumCiphertexts, Params, Plaintext, Result, TryFromPlaintext, + TryIntoPlaintext, Type, TypeName, TypeNameInstance, WithContext, }; -use seal_fhe::Plaintext as SealPlaintext; +use seal_fhe::{Ciphertext as SealCiphertext, Plaintext as SealPlaintext}; impl TryIntoPlaintext for [T; N] where @@ -68,7 +69,7 @@ where } } -impl FheProgramInputTrait for [T; N] where T: TypeName + TryIntoPlaintext {} +impl FheProgramPlaintextInput for [T; N] where T: TypeName + TryIntoPlaintext {} impl NumCiphertexts for [T; N] where @@ -76,3 +77,26 @@ where { const NUM_CIPHERTEXTS: usize = T::NUM_CIPHERTEXTS * N; } + +impl IntoCiphertext for [T; N] +where + T: IntoCiphertext, + Self: TypeName, +{ + fn into_ciphertext(&self) -> Ciphertext { + let element_ciphertexts = self + .iter() + .map(|v| v.into_ciphertext()) + .flat_map(|p| match p.inner { + InnerCiphertext::Seal(v) => v, + }) + .collect::>>(); + + Ciphertext { + inner: InnerCiphertext::Seal(element_ciphertexts), + data_type: Self::type_name(), + } + } +} + +impl FheProgramCiphertextInput for [T; N] where T: TypeName + IntoCiphertext {} diff --git a/sunscreen_runtime/src/builder.rs b/sunscreen_runtime/src/builder.rs index 15b7bd02c..b2baf579e 100644 --- a/sunscreen_runtime/src/builder.rs +++ b/sunscreen_runtime/src/builder.rs @@ -232,7 +232,7 @@ mod linked { }; use crate::{ - marker, Ciphertext, CompiledZkpProgram, Fhe, FheRuntime, FheZkp, FheZkpRuntime, + marker, Cipher, Ciphertext, CompiledZkpProgram, Fhe, FheRuntime, FheZkp, FheZkpRuntime, GenericRuntime, LinkedProof, NumCiphertexts, Params, Plaintext, PrivateKey, PublicKey, Result, Sdlp, SdlpProverKnowledge, SdlpVerifierKnowledge, TryFromPlaintext, TryIntoPlaintext, ZkpProgramInput, @@ -274,31 +274,61 @@ mod linked { /// /// Due to some quirks in the visibility warnings, this is marked `pub` and manually excluded /// from the `pub use linked::{}` export above. - #[derive(Clone, Debug)] - pub struct MessageInternal { + #[derive(Debug)] + pub struct MessageInternal { id: usize, len: usize, pt: Arc, zkp_type: Z, + _pt_marker: std::marker::PhantomData, + } + impl Clone for MessageInternal { + fn clone(&self) -> Self { + let MessageInternal { + id, + len, + pt, + zkp_type, + _pt_marker, + } = self; + MessageInternal { + id: *id, + len: *len, + pt: pt.clone(), + zkp_type: zkp_type.clone(), + _pt_marker: *_pt_marker, + } + } } /// A [`Plaintext`] message that can be [encrypted again](`LogProofBuilder::reencrypt`). #[derive(Clone, Debug)] - pub struct Message(MessageInternal<()>); + pub struct Message(MessageInternal); /// A [`Plaintext`] message that can be [encrypted again](`LogProofBuilder::reencrypt`) or /// [linked to a ZKP program](`LogProofBuilder::linked_input`). Create this with /// [`LogProofBuilder::encrypt_returning_link`]. #[derive(Debug)] - pub struct LinkedMessage(MessageInternal); + pub struct LinkedMessage(MessageInternal); - impl LinkedMessage { - fn from_message(msg: Message, zkp_type: Type) -> Self { + impl LinkedMessage { + fn from_message(msg: Message, zkp_type: Type) -> Self { LinkedMessage(MessageInternal { id: msg.0.id, len: msg.0.len, pt: msg.0.pt, zkp_type, + _pt_marker: std::marker::PhantomData, + }) + } + + fn coerce(self) -> LinkedMessage { + LinkedMessage(MessageInternal { + id: self.0.id, + len: self.0.len, + pt: self.0.pt, + zkp_type: self.0.zkp_type, + _pt_marker: std::marker::PhantomData, }) } } @@ -306,31 +336,32 @@ mod linked { mod private { pub trait Sealed {} - impl Sealed for super::Message {} - impl Sealed for super::LinkedMessage {} + impl Sealed for super::Message {} + impl Sealed for super::LinkedMessage {} } /// Indicates that the message is already added to the SDLP, and hence can be used as an /// argument to [`LogProofBuilder::reencrypt`]. - pub trait ExistingMessage: private::Sealed { + pub trait ExistingMessage: private::Sealed { /// Convert the message to the internal type. - fn as_internal(&self) -> MessageInternal<()>; + fn as_internal(&self) -> MessageInternal; } - impl ExistingMessage for Message { - fn as_internal(&self) -> MessageInternal<()> { + impl ExistingMessage for Message { + fn as_internal(&self) -> MessageInternal { self.0.clone() } } - impl ExistingMessage for LinkedMessage { - fn as_internal(&self) -> MessageInternal<()> { + impl ExistingMessage for LinkedMessage { + fn as_internal(&self) -> MessageInternal { let msg = self.0.clone(); MessageInternal { id: msg.id, len: msg.len, pt: msg.pt, zkp_type: (), + _pt_marker: std::marker::PhantomData, } } } @@ -405,7 +436,8 @@ mod linked { // linked proof fields compiled_zkp_program: Option<&'z CompiledZkpProgram>, - linked_inputs: Vec, + // we don't need the type information after recording the input + linked_inputs: Vec>, private_inputs: Vec, public_inputs: Vec, constant_inputs: Vec, @@ -468,7 +500,7 @@ mod linked { /// /// If you do not want to add the encryption statement to the proof, just use [the /// runtime](`crate::GenericRuntime::encrypt`) directly. - pub fn encrypt

(&mut self, message: &P, public_key: &'k PublicKey) -> Result + pub fn encrypt

(&mut self, message: &P, public_key: &'k PublicKey) -> Result> where P: TryIntoPlaintext + TypeName, { @@ -484,7 +516,7 @@ mod linked { &mut self, message: &P, private_key: &'k PrivateKey, - ) -> Result + ) -> Result> where P: TryIntoPlaintext + TypeName, { @@ -502,7 +534,7 @@ mod linked { &mut self, message: &P, public_key: &'k PublicKey, - ) -> Result<(Ciphertext, Message)> + ) -> Result<(Cipher

, Message

)> where P: TryIntoPlaintext + TypeName, { @@ -519,7 +551,7 @@ mod linked { &mut self, message: &P, private_key: &'k PrivateKey, - ) -> Result<(Ciphertext, Message)> + ) -> Result<(Cipher

, Message

)> where P: TryIntoPlaintext + TypeName, { @@ -531,7 +563,7 @@ mod linked { message: &P, key: Key<'k>, bounds: Option, - ) -> Result<(Ciphertext, Message)> + ) -> Result<(Cipher

, Message

)> where P: TryIntoPlaintext + TypeName, { @@ -555,6 +587,7 @@ mod linked { pt: Arc::new(plaintext_typed), len: idx_end - idx_start, zkp_type: (), + _pt_marker: std::marker::PhantomData, }; Ok((ct, Message(msg_internal))) } @@ -565,11 +598,11 @@ mod linked { /// plaintext. If this is not what you want, use [`Self::encrypt`]. /// /// This method assumes that you've created the `message` argument with _this_ builder. - pub fn reencrypt( + pub fn reencrypt>( &mut self, message: &E, public_key: &'k PublicKey, - ) -> Result { + ) -> Result> { // The existing message already has bounds, no need to recompute them. let bounds = None; self.encrypt_asymmetric_internal( @@ -586,11 +619,11 @@ mod linked { /// plaintext. If this is not what you want, use [`Self::encrypt_symmetric`]. /// /// This method assumes that you've created the `message` argument with _this_ builder. - pub fn reencrypt_symmetric( + pub fn reencrypt_symmetric>( &mut self, message: &E, private_key: &'k PrivateKey, - ) -> Result { + ) -> Result> { // The existing message already has bounds, no need to recompute them. let bounds = None; self.encrypt_symmetric_internal( @@ -614,9 +647,9 @@ mod linked { /// linked ZKP program. pub fn decrypt_returning_msg

( &mut self, - ciphertext: &Ciphertext, + ciphertext: &Cipher

, private_key: &'k PrivateKey, - ) -> Result<(P, Message)> + ) -> Result<(P, Message

)> where P: TryIntoPlaintext + TryFromPlaintext + TypeName, { @@ -625,10 +658,10 @@ mod linked { fn decrypt_internal

( &mut self, - ciphertext: &Ciphertext, + ciphertext: &Cipher

, private_key: &'k PrivateKey, bounds: Option, - ) -> Result<(P, Message)> + ) -> Result<(P, Message

)> where P: TryIntoPlaintext + TryFromPlaintext + TypeName, { @@ -665,20 +698,23 @@ mod linked { pt, len: end_idx - start_idx, zkp_type: (), + _pt_marker: std::marker::PhantomData, }; Ok((p, Message(msg_internal))) } - fn encrypt_asymmetric_internal( + fn encrypt_asymmetric_internal( &mut self, message: Msg, public_key: &'k PublicKey, bounds: Option, - ) -> Result { + ) -> Result> { let existing_idx = message.existing_id(); let mut i = 0; - self.runtime - .encrypt_map_components(&message, public_key, |m, ct, components| { + let ct = self.runtime.encrypt_map_components( + &message, + public_key, + |m, ct, components| { let message_id = if let Some(idx) = existing_idx { idx + i } else { @@ -698,18 +734,20 @@ mod linked { self.witness .push(BfvWitness::PublicKeyEncryption(components)); i += 1; - }) + }, + )?; + Ok(Cipher::new(ct.inner)) } - fn encrypt_symmetric_internal( + fn encrypt_symmetric_internal( &mut self, message: Msg, private_key: &'k PrivateKey, bounds: Option, - ) -> Result { + ) -> Result> { let existing_idx = message.existing_id(); let mut i = 0; - self.runtime.encrypt_symmetric_map_components( + let ct = self.runtime.encrypt_symmetric_map_components( &message, private_key, |m, ct, components| { @@ -734,7 +772,8 @@ mod linked { }); i += 1; }, - ) + )?; + Ok(Cipher::new(ct.inner)) } fn plaintext_typed

(&self, pt: &P) -> Result @@ -809,7 +848,7 @@ mod linked { &mut self, message: &P, public_key: &'k PublicKey, - ) -> Result<(Ciphertext, LinkedMessage)> + ) -> Result<(Cipher

, LinkedMessage

)> where P: LinkWithZkp + TryIntoPlaintext + TypeName, { @@ -833,7 +872,7 @@ mod linked { &mut self, message: &P, private_key: &'k PrivateKey, - ) -> Result<(Ciphertext, LinkedMessage)> + ) -> Result<(Cipher

, LinkedMessage

)> where P: LinkWithZkp + TryIntoPlaintext + TypeName, { @@ -858,9 +897,9 @@ mod linked { /// instead, and prove equality within a linked ZKP program. pub fn decrypt_returning_link

( &mut self, - ciphertext: &Ciphertext, + ciphertext: &Cipher

, private_key: &'k PrivateKey, - ) -> Result<(P, LinkedMessage)> + ) -> Result<(P, LinkedMessage

)> where P: LinkWithZkp + TryIntoPlaintext + TryFromPlaintext + TypeName, { @@ -891,8 +930,8 @@ mod linked { /// Add a linked private input to the ZKP program. /// /// This method assumes that you've created the `message` argument with _this_ builder. - pub fn linked_input(&mut self, message: LinkedMessage) -> &mut Self { - self.linked_inputs.push(message); + pub fn linked_input

(&mut self, message: LinkedMessage

) -> &mut Self { + self.linked_inputs.push(message.coerce()); self } diff --git a/sunscreen_runtime/src/lib.rs b/sunscreen_runtime/src/lib.rs index cbebb1b30..449943e3a 100644 --- a/sunscreen_runtime/src/lib.rs +++ b/sunscreen_runtime/src/lib.rs @@ -14,7 +14,7 @@ mod run; mod runtime; mod serialization; -use std::sync::Arc; +use std::{ops::Deref, sync::Arc}; use seal_fhe::{Ciphertext as SealCiphertext, Plaintext as SealPlaintext}; use serde::{Deserialize, Serialize}; @@ -177,6 +177,72 @@ impl InnerCiphertext { } } +#[derive(Clone, Deserialize, Serialize)] +/** + * A typed variant of [`Ciphertext`]. + */ +// TODO: possibly restrict T: FheType +pub struct Cipher { + /// The inner ciphertext + pub inner: Ciphertext, + _marker: std::marker::PhantomData, +} + +impl Deref for Cipher { + type Target = Ciphertext; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl Cipher { + pub(crate) fn new(inner: Ciphertext) -> Cipher { + Self { + inner, + _marker: std::marker::PhantomData, + } + } +} + +impl Cipher { + /// Cast a [`Ciphertext`] to a typed [`Cipher`]. Returns an error if the underlying + /// ciphertext datatype does not match `T`. + pub fn cast(ciphertext: Ciphertext) -> Result> { + let expected_type = Type { + is_encrypted: true, + ..T::type_name() + }; + if expected_type != ciphertext.data_type { + Err(Error::type_mismatch(&expected_type, &ciphertext.data_type)) + } else { + Ok(Self { + inner: ciphertext, + _marker: std::marker::PhantomData, + }) + } + } +} + +impl NumCiphertexts for Cipher +where + T: NumCiphertexts, +{ + const NUM_CIPHERTEXTS: usize = T::NUM_CIPHERTEXTS; +} + +impl TypeName for Cipher +where + T: FheType + TypeName, +{ + fn type_name() -> Type { + Type { + is_encrypted: true, + ..T::type_name() + } + } +} + #[derive(Clone, Deserialize, Serialize)] /** * An encryption of the given data type. Note, the data type is @@ -209,7 +275,10 @@ impl Ciphertext { * A trait that denotes this type can be used as an * argument to an FHE program. */ -pub trait FheProgramInputTrait: TryIntoPlaintext + TypeNameInstance {} +pub trait FheProgramPlaintextInput: TryIntoPlaintext + TypeNameInstance {} + +/// A trait that denotes this type can be used as a ciphertext argument to an FHE program. +pub trait FheProgramCiphertextInput: IntoCiphertext {} /** * An input argument to an Fhe Program. See [`crate::Runtime::run`]. @@ -218,12 +287,32 @@ pub enum FheProgramInput { /** * The argument is a ciphertext. */ - Ciphertext(Ciphertext), + Ciphertext(Box), /** * The argument is a plaintext. */ - Plaintext(Box), + Plaintext(Box), +} + +impl From for FheProgramInput { + fn from(value: T) -> Self { + Self::Ciphertext(Box::new(value)) + } +} + +// Without specialization, can't have two blanket impls, so provide: +/// A macro to cover the boilerplate for implementing `Into` for a plaintext +/// type.` +#[macro_export] +macro_rules! impl_into_fhe_program_plaintext_input { + ($t:ty) => { + impl From<$t> for sunscreen::FheProgramInput { + fn from(value: $t) -> Self { + Self::Plaintext(Box::new(value)) + } + } + }; } /** @@ -258,27 +347,12 @@ impl TypeNameInstance for ZkpProgramInput { impl TypeNameInstance for FheProgramInput { fn type_name_instance(&self) -> Type { match self { - Self::Ciphertext(c) => c.data_type.clone(), + Self::Ciphertext(c) => c.into_ciphertext().data_type.type_name_instance(), Self::Plaintext(p) => p.type_name_instance(), } } } -impl From for FheProgramInput { - fn from(val: Ciphertext) -> Self { - Self::Ciphertext(val) - } -} - -impl From for FheProgramInput -where - T: FheProgramInputTrait + 'static, -{ - fn from(val: T) -> Self { - Self::Plaintext(Box::new(val)) - } -} - /** * This trait denotes one may attempt to turn this type into a plaintext. */ @@ -295,6 +369,28 @@ impl TryIntoPlaintext for Plaintext { } } +/// This trait dentoes one may attempt to turn this type into a ciphertext. +#[allow(clippy::wrong_self_convention)] +pub trait IntoCiphertext { + /// Attempts to convert this type into a [`Ciphertext`]. + fn into_ciphertext(&self) -> Ciphertext; +} + +impl IntoCiphertext for Ciphertext { + fn into_ciphertext(&self) -> Ciphertext { + self.clone() + } +} + +impl IntoCiphertext for Cipher { + fn into_ciphertext(&self) -> Ciphertext { + self.inner.clone() + } +} + +impl FheProgramCiphertextInput for Ciphertext {} +impl FheProgramCiphertextInput for Cipher {} + /** * A trait for converting values into fields used by ZKPs. */ @@ -345,7 +441,7 @@ pub trait NumCiphertexts { * Denotes the given rust type is an encoding in an FHE scheme */ pub trait FheType: - TypeNameInstance + TryIntoPlaintext + TryFromPlaintext + FheProgramInputTrait + NumCiphertexts + TypeNameInstance + TryIntoPlaintext + TryFromPlaintext + FheProgramPlaintextInput + NumCiphertexts { } diff --git a/sunscreen_runtime/src/runtime.rs b/sunscreen_runtime/src/runtime.rs index d03913100..782695ec4 100644 --- a/sunscreen_runtime/src/runtime.rs +++ b/sunscreen_runtime/src/runtime.rs @@ -3,11 +3,11 @@ use std::time::Instant; use merlin::Transcript; -use crate::error::*; use crate::metadata::*; use crate::ProofBuilder; use crate::VerificationBuilder; use crate::ZkpProgramInput; +use crate::{error::*, Cipher}; use crate::{ run_program_unchecked, serialization::WithContext, Ciphertext, FheProgramInput, InnerCiphertext, InnerPlaintext, Plaintext, PrivateKey, PublicKey, SealCiphertext, SealData, @@ -140,7 +140,21 @@ where T: self::marker::Fhe, { /** - * Decrypts the given ciphertext into the type P. + * Decrypts the given ciphertext into the underlying type P. + */ + // TODO replace decrypt with this function, leaving the other named `decrypt_opaque`. + #[allow(unused)] + pub fn decrypt_TODO

(&self, ciphertext: &Cipher

, private_key: &PrivateKey) -> Result

+ where + P: TryFromPlaintext + TypeName, + { + let fhe_data = self.runtime_data.unwrap_fhe(); + let pt = self.decrypt_map_components::

(&ciphertext.inner, private_key, |_, _| ())?; + P::try_from_plaintext(&pt, &fhe_data.params) + } + + /** + * Decrypts the given opaque ciphertext into the provided type P. */ pub fn decrypt

(&self, ciphertext: &Ciphertext, private_key: &PrivateKey) -> Result

where @@ -353,13 +367,16 @@ where for i in arguments.drain(0..) { match i { - FheProgramInput::Ciphertext(c) => match c.inner { - InnerCiphertext::Seal(mut c) => { - for j in c.drain(0..) { - inputs.push(SealData::Ciphertext(j.data)); + FheProgramInput::Ciphertext(c) => { + let c = c.into_ciphertext(); + match c.inner { + InnerCiphertext::Seal(mut c) => { + for j in c.drain(0..) { + inputs.push(SealData::Ciphertext(j.data)); + } } } - }, + } FheProgramInput::Plaintext(p) => { let p = p.try_into_plaintext(&fhe_data.params)?; @@ -421,7 +438,7 @@ where * Returns [`Error::ParameterMismatch`] if the plaintext is incompatible with this runtime's * scheme. */ - pub fn encrypt

(&self, val: P, public_key: &PublicKey) -> Result + pub fn encrypt

(&self, val: P, public_key: &PublicKey) -> Result> where P: TryIntoPlaintext + TypeName, { @@ -442,7 +459,7 @@ where /// /// Returns [`Error::ParameterMismatch`] if the plaintext is incompatible with this runtime's /// scheme. - pub fn encrypt_symmetric

(&self, val: P, private_key: &PrivateKey) -> Result + pub fn encrypt_symmetric

(&self, val: P, private_key: &PrivateKey) -> Result> where P: TryIntoPlaintext + TypeName, { @@ -541,7 +558,7 @@ where val: &P, public_key: &PublicKey, mut f: impl FnMut(&SealPlaintext, &SealCiphertext, AsymmetricComponents), - ) -> Result + ) -> Result> where P: TryIntoPlaintext + TypeNameInstance, { @@ -574,7 +591,7 @@ where val: &P, private_key: &PrivateKey, mut f: impl FnMut(&SealPlaintext, &SealCiphertext, SymmetricComponents), - ) -> Result + ) -> Result> where P: TryIntoPlaintext + TypeNameInstance, { @@ -597,11 +614,11 @@ where // Use a seal encryption function to encrypt a list of inner seal plaintexts `pts`, // representing a runtime level plaintext of type `pt_type`, and return a runtime ciphertext // consisting of the list of respective inner seal ciphertexts. - fn aggregate_ciphertexts( + fn aggregate_ciphertexts( pt_type: &Type, pts: &[WithContext], mut enc_fn: F, - ) -> Result + ) -> Result> where F: FnMut(&SealPlaintext) -> seal_fhe::Result, { @@ -614,13 +631,13 @@ where }) }) .collect::>>()?; - Ok(Ciphertext { + Ok(Cipher::new(Ciphertext { data_type: Type { is_encrypted: true, ..pt_type.clone() }, inner: InnerCiphertext::Seal(cts), - }) + })) } }