diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 6fc1f149c..917de9fae 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -64,6 +64,15 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822" +[[package]] +name = "counter" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d84b66ce02c964fa8047286289b36797ce48a52a44034e013ce3e5219b6cb360" +dependencies = [ + "num-traits", +] + [[package]] name = "crc32fast" version = "1.2.0" @@ -799,6 +808,7 @@ dependencies = [ "bincode", "bitflags", "bytes", + "counter", "derive_more", "futures", "num", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 9bff48bc5..5b3301c01 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -12,19 +12,20 @@ homepage = "https://xain.io/" [dependencies] futures = "0.3.4" tokio = { version = "0.2.19", features = ["rt-core", "rt-threaded", "tcp", "time", "macros", "signal", "sync", "stream"] } -derive_more = { version = "0.99.3", default-features = false, features = [ "display", "from" , "as_ref", "as_mut"] } +derive_more = { version = "0.99.3", default-features = false, features = [ "display", "from" , "as_ref", "as_mut", "into", "index", "index_mut"] } rand = "0.7.3" rand_chacha = "0.2.2" serde = { version = "1.0.104", features = [ "derive" ] } bytes = "0.5.4" tracing = "0.1.13" sodiumoxide = "0.2.5" -num = "0.2.1" +num = { version = "0.2.1" } bincode = "1.2.1" thiserror = "1.0.16" anyhow = "1.0.28" bitflags = "1.2.1" paste = "0.1.12" +counter = "0.4.3" [[bin]] name = "coordinator" diff --git a/rust/src/certificate.rs b/rust/src/certificate.rs index cb3eea262..dbe3cbea7 100644 --- a/rust/src/certificate.rs +++ b/rust/src/certificate.rs @@ -1,54 +1,45 @@ -use derive_more::{AsMut, AsRef}; +use crate::PetError; -use crate::{crypto::ByteObject, PetError}; - -#[derive(AsRef, AsMut, Clone, Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq)] /// A dummy certificate. pub struct Certificate(Vec); -impl ByteObject for Certificate { - /// Create a certificate a slice of bytes. Fails if the length of the input is invalid. - fn from_slice(bytes: &[u8]) -> Option { - Some(Self(bytes.to_vec())) +#[allow(clippy::len_without_is_empty)] +impl Certificate { + #[allow(clippy::new_without_default)] + /// Create a certificate + pub fn new() -> Self { + Self(vec![0_u8; 32]) } - /// Create a certificate initialized to zero. - fn zeroed() -> Self { - Self(vec![0_u8; Self::BYTES]) + /// Get the length of the certificate. + pub fn len(&self) -> usize { + self.as_ref().len() } - /// Get the certificate as a slice. - fn as_slice(&self) -> &[u8] { - self.0.as_slice() + /// Validate a certificate + pub fn validate(&self) -> Result<(), PetError> { + Ok(()) } } -#[allow(clippy::len_without_is_empty)] -impl Certificate { - /// Get the number of bytes of a certificate. - pub const BYTES: usize = 32; - - /// Get the length of the serialized certificate. - pub fn len(&self) -> usize { - self.as_slice().len() - } - - /// Serialize the certificate into bytes. - pub fn serialize(&self) -> Vec { - self.as_slice().to_vec() +impl AsRef<[u8]> for Certificate { + /// Get a reference to the certificate. + fn as_ref(&self) -> &[u8] { + self.0.as_slice() } +} - /// Deserialize the certificate from bytes. Fails if the length of the input is invalid. - pub fn deserialize(bytes: &[u8]) -> Result { - Self::from_slice(bytes).ok_or(PetError::InvalidMessage) +impl From> for Certificate { + /// Create a certificate from bytes. + fn from(bytes: Vec) -> Self { + Self(bytes) } +} - /// Validate the certificate. - pub fn validate(&self) -> Result<(), PetError> { - if self.as_slice() == [0_u8; 32] { - Ok(()) - } else { - Err(PetError::InvalidMessage) - } +impl From<&[u8]> for Certificate { + /// Create a certificate from a slice of bytes. + fn from(slice: &[u8]) -> Self { + Self(slice.to_vec()) } } diff --git a/rust/src/coordinator.rs b/rust/src/coordinator.rs index 13e1afd9e..018e121a0 100644 --- a/rust/src/coordinator.rs +++ b/rust/src/coordinator.rs @@ -4,22 +4,18 @@ use std::{ default::Default, }; -use derive_more::{AsMut, AsRef}; -use sodiumoxide::{ - crypto::{box_, hash::sha256}, - randombytes::randombytes, -}; -use thiserror::Error; +use sodiumoxide::{self, crypto::hash::sha256, randombytes::randombytes}; use crate::{ crypto::{generate_encrypt_key_pair, ByteObject, SigningKeySeed}, - mask::{Integers, Mask, MaskIntegers, MaskedModel}, - message::{sum::SumMessage, sum2::Sum2Message, update::UpdateMessage}, - model::Model, + mask::Mask, + message::{MessageOpen, PayloadOwned, Sum2Owned, SumOwned, UpdateOwned}, + utils::is_eligible, CoordinatorPublicKey, CoordinatorSecretKey, InitError, LocalSeedDict, + ParticipantPublicKey, ParticipantTaskSignature, PetError, SeedDict, @@ -31,18 +27,19 @@ use crate::{ /// Error that occurs when the current round fails #[derive(Debug, Eq, PartialEq)] pub enum RoundFailed { - /// Round failed because ambiguous masks were computed by the sum participants. + /// Round failed because ambiguous masks were computed by + /// a majority of sum participants AmbiguousMasks, - /// Round failed because no mask was submitted by any sum participant. + /// Round failed because no mask hash was selected by any sum + /// participant NoMask, - /// Round failed because no model could be unmasked. - NoModel, } -/// A dictionary created during the sum2 phase of the protocol. It counts the model masks. +/// A dictionary created during the sum2 phase of the protocol. It counts the model masks +/// represented by their hashes. pub type MaskDict = HashMap; -#[derive(Debug, PartialEq, Eq, Copy, Clone)] +#[derive(Debug, PartialEq, Copy, Clone)] /// Round phases of a coordinator. pub enum Phase { Idle, @@ -51,60 +48,8 @@ pub enum Phase { Sum2, } -/// Events the protocol emits. -#[derive(Debug, PartialEq)] -pub enum ProtocolEvent { - /// The round starts with the given parameters. The coordinator is - /// now in the sum phase. - StartSum(RoundParameters), - - /// The sum phase finished and produced the given sum - /// dictionary. The coordinator is now in the update phase. - StartUpdate(SumDict), - - /// The update phase finished and produced the given seed - /// dictionary. The coordinator is now in the sum2 phase. - StartSum2(SeedDict), - - /// The sum2 phase finished and produced a global model. The - /// coordinator is now back to the idle phase. - EndRound(Option<()>), -} - -#[derive(AsRef, AsMut, Clone, Debug, PartialEq, Eq)] -/// A seed for a round. -pub struct RoundSeed(box_::Seed); - -impl ByteObject for RoundSeed { - /// Create a round seed from a slice of bytes. Fails if the length of the input is invalid. - fn from_slice(bytes: &[u8]) -> Option { - box_::Seed::from_slice(bytes).map(Self) - } - - /// Create a round seed initialized to zero. - fn zeroed() -> Self { - Self(box_::Seed([0_u8; Self::BYTES])) - } - - /// Get the round seed as a slice. - fn as_slice(&self) -> &[u8] { - self.0.as_ref() - } -} - -impl RoundSeed { - /// Get the number of bytes of a round seed. - pub const BYTES: usize = box_::SEEDBYTES; - - /// Generate a random round seed. - pub fn generate() -> Self { - // safe unwrap: length of slice is guaranteed by constants - Self::from_slice_unchecked(randombytes(Self::BYTES).as_slice()) - } -} - /// A coordinator in the PET protocol layer. -pub struct Coordinator { +pub struct Coordinator { // credentials pk: CoordinatorPublicKey, // 32 bytes sk: CoordinatorSecretKey, // 32 bytes @@ -112,7 +57,7 @@ pub struct Coordinator { // round parameters sum: f64, update: f64, - seed: RoundSeed, // 32 bytes + seed: Vec, // 32 bytes min_sum: usize, min_update: usize, phase: Phase, @@ -125,29 +70,43 @@ pub struct Coordinator { /// Dictionary built during the sum2 phase. mask_dict: MaskDict, - // global models - model: Option>, - masked_model: Option, - - /// Events emitted by the state machine. + /// Events emitted by the state machine events: VecDeque, } -impl Default for Coordinator { +/// Events the protocol emits. +#[derive(Debug)] +pub enum ProtocolEvent { + /// The round starts with the given parameters. The coordinator is + /// now in the sum phase. + StartSum(RoundParameters), + + /// The sum phase finished and produced the given sum + /// dictionary. The coordinator is now in the update phase. + StartUpdate(SumDict), + + /// The update phase finished and produced the given seed + /// dictionary. The coordinator is now in the sum2 phase. + StartSum2(SeedDict), + + /// The sum2 phase finished and produced the given mask seed. The + /// coordinator is now back to the idle phase. + EndRound(Option), +} + +impl Default for Coordinator { fn default() -> Self { let pk = CoordinatorPublicKey::zeroed(); let sk = CoordinatorSecretKey::zeroed(); let sum = 0.01_f64; let update = 0.1_f64; - let seed = RoundSeed::zeroed(); + let seed = vec![0_u8; 32]; let min_sum = 1_usize; let min_update = 3_usize; let phase = Phase::Idle; let sum_dict = SumDict::new(); let seed_dict = SeedDict::new(); let mask_dict = MaskDict::new(); - let model = None; - let masked_model = None; let events = VecDeque::new(); Self { pk, @@ -161,41 +120,31 @@ impl Default for Coordinator { sum_dict, seed_dict, mask_dict, - model, - masked_model, events, } } } -pub trait Coordinators: Sized { - define_trait_fields!( - pk, CoordinatorPublicKey; - sk, CoordinatorSecretKey; - sum, f64; - update, f64; - min_sum, usize; - min_update, usize; - seed, RoundSeed; - phase, Phase; - sum_dict, SumDict; - seed_dict, SeedDict; - mask_dict, MaskDict; - masked_model, Option; - events, VecDeque; - ); - +impl Coordinator { /// Create a coordinator. Fails if there is insufficient system entropy to generate secrets. - fn new() -> Result; + pub fn new() -> Result { + // crucial: init must be called before anything else in this module + sodiumoxide::init().or(Err(InitError))?; + let seed = randombytes(32); + Ok(Self { + seed, + ..Default::default() + }) + } - /// Emit an event. - fn emit_event(&mut self, event: ProtocolEvent) { - self.events_mut().push_back(event); + /// Emit an event + pub fn emit_event(&mut self, event: ProtocolEvent) { + self.events.push_back(event); } - /// Retrieve the next event. - fn next_event(&mut self) -> Option { - self.events_mut().pop_front() + /// Retrieve the next event + pub fn next_event(&mut self) -> Option { + self.events.pop_front() } fn message_open(&self) -> MessageOpen<'_, '_> { @@ -206,40 +155,61 @@ pub trait Coordinators: Sized { } /// Validate and handle a sum, update or sum2 message. - fn handle_message(&mut self, bytes: &[u8]) -> Result<(), PetError> { - match self.phase() { - Phase::Idle => Err(PetError::InvalidMessage), - Phase::Sum => self.handle_sum_message(bytes), - Phase::Update => self.handle_update_message(bytes), - Phase::Sum2 => self.handle_sum2_message(bytes), + pub fn handle_message(&mut self, bytes: &[u8]) -> Result<(), PetError> { + let message = self + .message_open() + .open(&bytes) + .map_err(|_| PetError::InvalidMessage)?; + let participant_pk = message.header.participant_pk; + match (self.phase, message.payload) { + (Phase::Sum, PayloadOwned::Sum(msg)) => self.handle_sum_message(participant_pk, msg), + (Phase::Update, PayloadOwned::Update(msg)) => { + self.handle_update_message(participant_pk, msg) + } + (Phase::Sum2, PayloadOwned::Sum2(msg)) => self.handle_sum2_message(participant_pk, msg), + _ => Err(PetError::InvalidMessage), } } /// Validate and handle a sum message. - fn handle_sum_message(&mut self, bytes: &[u8]) -> Result<(), PetError> { - let msg = SumMessage::open(bytes, self.pk(), self.sk())?; - msg.certificate().validate()?; - self.validate_sum_task(msg.sum_signature(), msg.pk())?; - self.add_sum_participant(msg.pk(), msg.ephm_pk())?; + fn handle_sum_message( + &mut self, + pk: ParticipantPublicKey, + message: SumOwned, + ) -> Result<(), PetError> { + self.validate_sum_task(&pk, &message.sum_signature)?; + self.sum_dict.insert(pk, message.ephm_pk); Ok(()) } /// Validate and handle an update message. - fn handle_update_message(&mut self, bytes: &[u8]) -> Result<(), PetError> { - let msg = UpdateMessage::open(bytes, self.pk(), self.sk())?; - msg.certificate().validate()?; - self.validate_update_task(msg.sum_signature(), msg.update_signature(), msg.pk())?; - self.aggregate_masked_model(msg.masked_model())?; - self.add_local_seed_dict(msg.pk(), msg.local_seed_dict())?; + fn handle_update_message( + &mut self, + pk: ParticipantPublicKey, + message: UpdateOwned, + ) -> Result<(), PetError> { + let UpdateOwned { + sum_signature, + update_signature, + local_seed_dict, + .. + } = message; + self.validate_update_task(&pk, &sum_signature, &update_signature)?; + self.add_local_seed_dict(&pk, &local_seed_dict)?; Ok(()) } /// Validate and handle a sum2 message. - fn handle_sum2_message(&mut self, bytes: &[u8]) -> Result<(), PetError> { - let msg = Sum2Message::open(bytes, self.pk(), self.sk())?; - msg.certificate().validate()?; - self.validate_sum_task(msg.sum_signature(), msg.pk())?; - self.add_mask(msg.pk(), msg.mask())?; + fn handle_sum2_message( + &mut self, + pk: ParticipantPublicKey, + message: Sum2Owned, + ) -> Result<(), PetError> { + if !self.sum_dict.contains_key(&pk) { + return Err(PetError::InvalidMessage); + } + self.validate_sum_task(&pk, &message.sum_signature)?; + self.add_mask(&pk, message.mask).unwrap(); Ok(()) } @@ -249,8 +219,8 @@ pub trait Coordinators: Sized { pk: &SumParticipantPublicKey, sum_signature: &ParticipantTaskSignature, ) -> Result<(), PetError> { - if pk.verify_detached(sum_signature, &[self.seed().as_slice(), b"sum"].concat()) - && sum_signature.is_eligible(*self.sum()) + if pk.verify_detached(sum_signature, &[self.seed.as_slice(), b"sum"].concat()) + && is_eligible(sum_signature, self.sum) { Ok(()) } else { @@ -265,13 +235,13 @@ pub trait Coordinators: Sized { sum_signature: &ParticipantTaskSignature, update_signature: &ParticipantTaskSignature, ) -> Result<(), PetError> { - if pk.verify_detached(sum_signature, &[self.seed().as_slice(), b"sum"].concat()) + if pk.verify_detached(sum_signature, &[self.seed.as_slice(), b"sum"].concat()) && pk.verify_detached( update_signature, - &[self.seed().as_slice(), b"update"].concat(), + &[self.seed.as_slice(), b"update"].concat(), ) - && !sum_signature.is_eligible(*self.sum()) - && update_signature.is_eligible(*self.update()) + && !is_eligible(sum_signature, self.sum) + && is_eligible(update_signature, self.update) { Ok(()) } else { @@ -279,65 +249,31 @@ pub trait Coordinators: Sized { } } - /// Add a sum participant to the sum dictionary. Fails if it is a repetition. - fn add_sum_participant( - &mut self, - pk: &SumParticipantPublicKey, - ephm_pk: &SumParticipantEphemeralPublicKey, - ) -> Result<(), PetError> { - if !self.sum_dict().contains_key(pk) { - self.sum_dict_mut().insert(*pk, *ephm_pk); - Ok(()) - } else { - Err(PetError::InvalidMessage) - } - } - /// Freeze the sum dictionary. fn freeze_sum_dict(&mut self) { - *self.seed_dict_mut() = self - .sum_dict() + self.seed_dict = self + .sum_dict .keys() .map(|pk| (*pk, LocalSeedDict::new())) .collect(); } - /// Aggregate a local masked model to the global masked model. Fails if the model types don't - /// conform. - fn aggregate_masked_model(&mut self, local_masked_model: &MaskedModel) -> Result<(), PetError> { - *self.masked_model_mut() = if let Some(global_masked_model) = self.masked_model() { - Some( - global_masked_model - .aggregate(local_masked_model) - .or(Err(PetError::InvalidMessage))?, - ) - } else { - Some(local_masked_model.clone()) - }; - Ok(()) - } - - /// Add a local seed dictionary to the seed dictionary. Fails if it contains invalid keys or it - /// is a repetition. + /// Add a local seed dictionary to the seed dictionary. Fails if it contains invalid keys. fn add_local_seed_dict( &mut self, pk: &UpdateParticipantPublicKey, local_seed_dict: &LocalSeedDict, ) -> Result<(), PetError> { - if local_seed_dict.keys().len() == self.sum_dict().keys().len() + if local_seed_dict.keys().len() == self.sum_dict.keys().len() && local_seed_dict .keys() - .all(|pk| self.sum_dict().contains_key(pk)) - && self - .seed_dict() - .values() - .next() - .map_or(true, |dict| !dict.contains_key(pk)) + .all(|pk| self.sum_dict.contains_key(pk)) { for (sum_pk, seed) in local_seed_dict { - self.seed_dict_mut() + // safe unwrap: existence of `sum_pk` is guaranteed by `freeze_sum_dict()` + self.seed_dict .get_mut(sum_pk) - .ok_or(PetError::InvalidMessage)? + .unwrap() .insert(*pk, seed.clone()); } Ok(()) @@ -348,50 +284,57 @@ pub trait Coordinators: Sized { /// Add a mask to the mask dictionary. Fails if the sum participant didn't register in the sum /// phase or it is a repetition. - fn add_mask(&mut self, pk: &SumParticipantPublicKey, mask: &Mask) -> Result<(), PetError> { - if self.sum_dict_mut().remove(pk).is_none() { - Err(PetError::InvalidMessage) - } else if let Some(count) = self.mask_dict_mut().get_mut(mask) { + fn add_mask(&mut self, pk: &SumParticipantPublicKey, mask: Mask) -> Result<(), PetError> { + // We move the participant key here to make sure a participant + // cannot submit a mask multiple times + if self.sum_dict.remove(pk).is_none() { + return Err(PetError::InvalidMessage); + } + + if let Some(count) = self.mask_dict.get_mut(&mask) { *count += 1; - Ok(()) } else { - self.mask_dict_mut().insert(mask.clone(), 1); - Ok(()) + self.mask_dict.insert(mask, 1); } + + Ok(()) } /// Freeze the mask dictionary. - fn freeze_mask_dict(&self) -> Result<&Mask, RoundFailed> { - if self.mask_dict().is_empty() { - Err(RoundFailed::NoMask) - } else { - let (mask, _) = self.mask_dict().iter().fold( + fn freeze_mask_dict(&mut self) -> Result { + if self.mask_dict.is_empty() { + return Err(RoundFailed::NoMask); + } + + self.mask_dict + .drain() + .fold( (None, 0_usize), - |(unique_mask, unique_count), (mask, count)| match unique_count.cmp(count) { - Ordering::Less => (Some(mask), *count), + |(unique_mask, unique_count), (mask, count)| match unique_count.cmp(&count) { + Ordering::Less => (Some(mask), count), Ordering::Greater => (unique_mask, unique_count), Ordering::Equal => (None, unique_count), }, - ); - mask.ok_or(RoundFailed::AmbiguousMasks) - } + ) + .0 + .ok_or(RoundFailed::AmbiguousMasks) } /// Clear the round dictionaries. fn clear_round_dicts(&mut self) { - self.sum_dict_mut().clear(); - self.sum_dict_mut().shrink_to_fit(); - self.seed_dict_mut().clear(); - self.seed_dict_mut().shrink_to_fit(); - self.mask_dict_mut().clear(); - self.mask_dict_mut().shrink_to_fit(); + self.sum_dict.clear(); + self.sum_dict.shrink_to_fit(); + self.seed_dict.clear(); + self.seed_dict.shrink_to_fit(); + self.mask_dict.clear(); + self.mask_dict.shrink_to_fit(); } /// Generate fresh round credentials. fn gen_round_keypair(&mut self) { let (pk, sk) = generate_encrypt_key_pair(); - *self.pk_mut() = pk; - *self.sk_mut() = sk; + self.pk = pk; + self.sk = sk; } /// Update the round threshold parameters (dummy). @@ -400,48 +343,75 @@ pub trait Coordinators: Sized { /// Update the seed round parameter. fn update_round_seed(&mut self) { // safe unwrap: `sk` and `seed` have same number of bytes - let (_, sk) = - SigningKeySeed::from_slice_unchecked(self.sk().as_slice()).derive_signing_key_pair(); + let (_, sk) = SigningKeySeed::from_slice(self.sk.as_slice()) + .unwrap() + .derive_signing_key_pair(); let signature = sk.sign_detached( &[ - self.seed().as_slice(), - &self.sum().to_le_bytes(), - &self.update().to_le_bytes(), + self.seed.as_slice(), + &self.sum.to_le_bytes(), + &self.update.to_le_bytes(), ] .concat(), ); - // safe unwrap: length of slice is guaranteed by constants - *self.seed_mut() = - RoundSeed::from_slice_unchecked(sha256::hash(signature.as_slice()).as_ref()); + self.seed = sha256::hash(signature.as_slice()).as_ref().to_vec(); + } + + /// Transition to the next phase if the protocol conditions are satisfied. + pub fn try_phase_transition(&mut self) { + match self.phase { + Phase::Idle => { + self.proceed_sum_phase(); + self.try_phase_transition(); + } + Phase::Sum => { + if self.has_enough_sums() { + self.proceed_update_phase(); + self.try_phase_transition(); + } + } + Phase::Update => { + if self.has_enough_seeds() { + self.proceed_sum2_phase(); + self.try_phase_transition(); + } + } + Phase::Sum2 => { + if self.has_enough_masks() { + self.proceed_idle_phase(); + self.try_phase_transition(); + } + } + } } /// Check whether enough sum participants submitted their ephemeral keys to start the update /// phase. fn has_enough_sums(&self) -> bool { - self.sum_dict().len() >= *self.min_sum() + self.sum_dict.len() >= self.min_sum } /// Check whether enough update participants submitted their models and seeds to start the sum2 /// phase. fn has_enough_seeds(&self) -> bool { - self.seed_dict() + self.seed_dict .values() .next() - .map(|dict| dict.len() >= *self.min_update()) + .map(|dict| dict.len() >= self.min_update) .unwrap_or(false) } /// Check whether enough sum participants submitted their masks to start the idle phase. fn has_enough_masks(&self) -> bool { - let mask_count = self.mask_dict().values().sum::(); - mask_count >= *self.min_sum() + let mask_count = self.mask_dict.values().sum::(); + mask_count >= self.min_sum } /// End the idle phase and proceed to the sum phase to start the round. fn proceed_sum_phase(&mut self) { info!("going to sum phase"); self.gen_round_keypair(); - *self.phase_mut() = Phase::Sum; + self.phase = Phase::Sum; self.emit_event(ProtocolEvent::StartSum(self.round_parameters())); } @@ -449,180 +419,48 @@ pub trait Coordinators: Sized { fn proceed_update_phase(&mut self) { info!("going to update phase"); self.freeze_sum_dict(); - *self.phase_mut() = Phase::Update; - self.emit_event(ProtocolEvent::StartUpdate(self.sum_dict().clone())); + self.phase = Phase::Update; + self.emit_event(ProtocolEvent::StartUpdate(self.sum_dict.clone())); } /// End the update phase and proceed to the sum2 phase. fn proceed_sum2_phase(&mut self) { info!("going to sum2 phase"); - *self.phase_mut() = Phase::Sum2; - self.emit_event(ProtocolEvent::StartSum2(self.seed_dict().clone())); - } - - /// Prepare the coordinator for a new round and go back to the initial phase. - fn start_new_round(&mut self) { - self.clear_round_dicts(); - self.update_round_thresholds(); - self.update_round_seed(); - *self.phase_mut() = Phase::Idle; - } - - fn round_parameters(&self) -> RoundParameters { - RoundParameters { - pk: *self.pk(), - sum: *self.sum(), - update: *self.update(), - seed: self.seed().clone(), - } - } -} - -pub trait MaskCoordinators: Coordinators { - define_trait_fields!(model, Option>); - - /// Unmask the masked model with a mask. - fn unmask_model(&self, mask: &Mask) -> Result, RoundFailed>; - - /// Transition to the next phase if the protocol conditions are satisfied. - fn try_phase_transition(&mut self) { - match self.phase() { - Phase::Idle => { - self.proceed_sum_phase(); - self.try_phase_transition(); - } - Phase::Sum => { - if self.has_enough_sums() { - self.proceed_update_phase(); - self.try_phase_transition(); - } - } - Phase::Update => { - if self.has_enough_seeds() { - self.proceed_sum2_phase(); - self.try_phase_transition(); - } - } - Phase::Sum2 => { - if self.has_enough_masks() { - self.proceed_idle_phase(); - self.try_phase_transition(); - } - } - } + self.phase = Phase::Sum2; + self.emit_event(ProtocolEvent::StartSum2(self.seed_dict.clone())); } /// End the sum2 phase and proceed to the idle phase to end the round. fn proceed_idle_phase(&mut self) { info!("going to idle phase"); - let outcome = if let Ok(mask) = self.freeze_mask_dict() { - if let Ok(model) = self.unmask_model(mask) { - *self.model_mut() = Some(model); - Some(()) - } else { - None - } - } else { - None - }; + let outcome = self.freeze_mask_dict().ok(); self.emit_event(ProtocolEvent::EndRound(outcome)); self.start_new_round(); } - /// Cancel the current round and restart a new one. - fn reset(&mut self) { - self.events_mut().clear(); + /// Cancel the current round and restart a new one + pub fn reset(&mut self) { + self.events.clear(); self.emit_event(ProtocolEvent::EndRound(None)); self.start_new_round(); self.try_phase_transition(); } -} - -impl Coordinators for Coordinator { - derive_trait_fields!( - pk, CoordinatorPublicKey; - sk, CoordinatorSecretKey; - sum, f64; - update, f64; - min_sum, usize; - min_update, usize; - seed, RoundSeed; - phase, Phase; - sum_dict, SumDict; - seed_dict, SeedDict; - mask_dict, MaskDict; - masked_model, Option; - events, VecDeque; - ); - /// Create a coordinator. Fails if there is insufficient system entropy to generate secrets. - fn new() -> Result { - // crucial: init must be called before anything else in this module - sodiumoxide::init().or(Err(InitError))?; - let seed = RoundSeed::generate(); - Ok(Self { - seed, - ..Default::default() - }) - } -} - -impl MaskCoordinators for Coordinator { - derive_trait_fields!(model, Option>); - - fn unmask_model(&self, mask: &Mask) -> Result, RoundFailed> { - let no_models = self.seed_dict.values().next().map_or(0, |dict| dict.len()); - if let Some(masked_model) = self.masked_model() { - masked_model - .unmask(mask, no_models) - .or(Err(RoundFailed::NoModel)) - } else { - Err(RoundFailed::NoModel) - } - } -} - -impl MaskCoordinators for Coordinator { - derive_trait_fields!(model, Option>); - - fn unmask_model(&self, mask: &Mask) -> Result, RoundFailed> { - let no_models = self.seed_dict.values().next().map_or(0, |dict| dict.len()); - if let Some(masked_model) = self.masked_model() { - masked_model - .unmask(mask, no_models) - .or(Err(RoundFailed::NoModel)) - } else { - Err(RoundFailed::NoModel) - } - } -} - -impl MaskCoordinators for Coordinator { - derive_trait_fields!(model, Option>); - - fn unmask_model(&self, mask: &Mask) -> Result, RoundFailed> { - let no_models = self.seed_dict.values().next().map_or(0, |dict| dict.len()); - if let Some(masked_model) = self.masked_model() { - masked_model - .unmask(mask, no_models) - .or(Err(RoundFailed::NoModel)) - } else { - Err(RoundFailed::NoModel) - } + /// Prepare the coordinator for a new round and go back to the + /// initial phase + fn start_new_round(&mut self) { + self.clear_round_dicts(); + self.update_round_thresholds(); + self.update_round_seed(); + self.phase = Phase::Idle; } -} -impl MaskCoordinators for Coordinator { - derive_trait_fields!(model, Option>); - - fn unmask_model(&self, mask: &Mask) -> Result, RoundFailed> { - let no_models = self.seed_dict.values().next().map_or(0, |dict| dict.len()); - if let Some(masked_model) = self.masked_model() { - masked_model - .unmask(mask, no_models) - .or(Err(RoundFailed::NoModel)) - } else { - Err(RoundFailed::NoModel) + pub fn round_parameters(&self) -> RoundParameters { + RoundParameters { + pk: self.pk, + sum: self.sum, + update: self.update, + seed: self.seed.clone(), } } } @@ -639,75 +477,60 @@ pub struct RoundParameters { pub update: f64, /// The random round seed. - pub seed: RoundSeed, + pub seed: Vec, } #[cfg(test)] mod tests { - use std::iter; - - use num::{bigint::BigUint, traits::identities::Zero}; - use super::*; use crate::{ crypto::*, - mask::{ - config::{BoundType, DataType, GroupType, MaskConfigs, ModelType}, - seed::MaskSeed, - }, + mask::{Mask, MaskSeed}, }; #[test] fn test_coordinator() { - let coord = Coordinator::::new().unwrap(); + let coord = Coordinator::new().unwrap(); assert_eq!(coord.pk, PublicEncryptKey::zeroed()); assert_eq!(coord.sk, SecretEncryptKey::zeroed()); assert!(coord.sum >= 0. && coord.sum <= 1.); assert!(coord.update >= 0. && coord.update <= 1.); - assert_eq!(coord.seed.as_slice().len(), 32); + assert_eq!(coord.seed.len(), 32); assert!(coord.min_sum >= 1); assert!(coord.min_update >= 3); assert_eq!(coord.phase, Phase::Idle); assert_eq!(coord.sum_dict, SumDict::new()); assert_eq!(coord.seed_dict, SeedDict::new()); assert_eq!(coord.mask_dict, MaskDict::new()); - assert_eq!(coord.model, None); - assert_eq!(coord.masked_model, None); } #[test] fn test_validate_sum_task() { - let mut coord = Coordinator::::new().unwrap(); - coord.sum = 0.5_f64; - coord.update = 0.5_f64; - coord.seed = RoundSeed::from_slice_unchecked(&[ + let mut coord = Coordinator::new().unwrap(); + coord.seed = vec![ 229, 16, 164, 40, 138, 161, 23, 161, 175, 102, 13, 103, 229, 229, 163, 56, 184, 250, 190, 44, 91, 69, 246, 222, 64, 101, 139, 22, 126, 6, 103, 238, - ]); - - // eligible sum signature + ]; let sum_signature = Signature::from_slice_unchecked(&[ - 216, 122, 81, 56, 190, 176, 44, 37, 167, 89, 45, 93, 82, 92, 147, 208, 158, 65, 145, - 253, 121, 35, 80, 38, 4, 37, 65, 244, 185, 101, 59, 124, 21, 22, 184, 234, 226, 78, - 255, 85, 112, 206, 76, 140, 216, 39, 172, 76, 0, 172, 239, 189, 106, 64, 137, 185, 123, - 132, 115, 14, 160, 116, 82, 7, + 106, 152, 91, 255, 122, 191, 159, 252, 180, 225, 105, 182, 30, 16, 99, 187, 220, 139, + 88, 105, 112, 224, 167, 249, 76, 12, 108, 182, 144, 208, 55, 80, 191, 47, 246, 87, 213, + 158, 237, 197, 199, 181, 91, 232, 197, 136, 230, 155, 56, 106, 217, 129, 200, 31, 113, + 254, 148, 234, 134, 152, 173, 69, 51, 13, ]); let pk = PublicSigningKey::from_slice_unchecked(&[ - 76, 128, 23, 65, 195, 57, 190, 223, 67, 224, 102, 139, 140, 90, 67, 160, 106, 181, 7, - 196, 245, 56, 193, 51, 15, 212, 9, 153, 61, 152, 173, 165, + 130, 93, 138, 240, 229, 140, 60, 97, 160, 189, 208, 185, 248, 206, 146, 160, 53, 173, + 146, 163, 35, 233, 191, 177, 72, 121, 136, 23, 32, 241, 181, 165, ]); - assert_eq!(coord.validate_sum_task(&sum_signature, &pk).unwrap(), ()); - - // ineligible sum signature + assert_eq!(coord.validate_sum_task(&pk, &sum_signature).unwrap(), ()); let sum_signature = Signature::from_slice_unchecked(&[ - 75, 17, 216, 121, 214, 15, 222, 250, 0, 172, 158, 190, 201, 132, 251, 15, 149, 4, 127, - 110, 214, 208, 17, 93, 236, 103, 199, 193, 74, 224, 243, 79, 217, 237, 184, 104, 126, - 203, 18, 189, 248, 237, 116, 163, 42, 32, 236, 96, 181, 151, 144, 252, 211, 56, 141, - 98, 108, 248, 231, 248, 61, 200, 184, 13, + 237, 143, 229, 127, 38, 65, 45, 145, 131, 233, 178, 250, 81, 211, 224, 103, 236, 91, + 82, 56, 19, 186, 236, 134, 19, 124, 16, 54, 148, 121, 206, 31, 71, 2, 11, 90, 41, 183, + 56, 58, 216, 3, 199, 181, 195, 118, 43, 185, 173, 25, 62, 186, 146, 14, 147, 24, 14, + 191, 118, 202, 185, 124, 125, 9, ]); let pk = PublicSigningKey::from_slice_unchecked(&[ - 200, 198, 194, 36, 111, 82, 127, 148, 245, 223, 158, 98, 142, 50, 65, 51, 7, 234, 201, - 148, 45, 56, 85, 65, 75, 128, 178, 175, 101, 93, 241, 162, + 121, 99, 230, 84, 169, 21, 227, 76, 114, 4, 61, 21, 68, 153, 79, 43, 111, 201, 28, 152, + 111, 145, 208, 17, 156, 93, 67, 74, 56, 40, 202, 149, ]); assert_eq!( coord.validate_sum_task(&pk, &sum_signature).unwrap_err(), @@ -717,30 +540,26 @@ mod tests { #[test] fn test_validate_update_task() { - let mut coord = Coordinator::::new().unwrap(); - coord.sum = 0.5_f64; - coord.update = 0.5_f64; - coord.seed = RoundSeed::from_slice_unchecked(&[ + let mut coord = Coordinator::new().unwrap(); + coord.seed = vec![ 229, 16, 164, 40, 138, 161, 23, 161, 175, 102, 13, 103, 229, 229, 163, 56, 184, 250, 190, 44, 91, 69, 246, 222, 64, 101, 139, 22, 126, 6, 103, 238, - ]); - - // ineligible sum signature and eligible update signature + ]; let sum_signature = Signature::from_slice_unchecked(&[ - 206, 154, 228, 165, 240, 196, 64, 106, 135, 124, 140, 83, 15, 188, 229, 78, 38, 34, - 254, 241, 7, 23, 44, 147, 6, 195, 158, 227, 250, 159, 60, 214, 42, 103, 145, 69, 121, - 165, 115, 196, 120, 164, 108, 200, 114, 200, 16, 21, 208, 233, 83, 176, 70, 77, 64, - 141, 65, 63, 236, 184, 250, 127, 59, 8, + 184, 138, 175, 209, 149, 211, 214, 237, 125, 97, 56, 97, 206, 13, 111, 107, 227, 146, + 40, 41, 210, 179, 5, 83, 113, 185, 6, 3, 221, 135, 128, 74, 20, 120, 102, 182, 16, 138, + 58, 94, 7, 128, 151, 50, 10, 107, 253, 73, 126, 36, 244, 141, 254, 34, 113, 71, 196, + 127, 18, 96, 223, 176, 67, 10, ]); let update_signature = Signature::from_slice_unchecked(&[ - 76, 195, 29, 117, 72, 226, 246, 103, 166, 245, 16, 122, 235, 107, 96, 111, 149, 231, - 216, 62, 1, 206, 139, 127, 208, 254, 118, 43, 0, 193, 54, 40, 2, 144, 240, 162, 240, - 226, 223, 0, 228, 59, 13, 252, 42, 34, 16, 22, 202, 30, 166, 138, 231, 2, 125, 123, 75, - 146, 103, 149, 95, 7, 177, 15, + 71, 51, 166, 220, 84, 170, 245, 60, 139, 79, 238, 74, 172, 122, 130, 47, 188, 168, 114, + 237, 210, 210, 234, 7, 123, 88, 73, 173, 174, 187, 82, 140, 41, 6, 44, 202, 255, 180, + 36, 186, 170, 97, 164, 155, 93, 21, 136, 114, 208, 246, 158, 254, 242, 12, 217, 148, + 27, 206, 44, 52, 204, 55, 4, 13, ]); let pk = PublicSigningKey::from_slice_unchecked(&[ - 220, 150, 230, 193, 226, 222, 50, 73, 44, 227, 70, 25, 58, 237, 34, 184, 151, 253, 127, - 252, 13, 23, 135, 194, 244, 12, 139, 17, 34, 61, 9, 92, + 106, 233, 139, 112, 104, 250, 253, 242, 74, 19, 188, 176, 211, 198, 17, 98, 132, 9, + 220, 253, 191, 119, 159, 138, 134, 250, 244, 193, 58, 244, 218, 231, ]); assert_eq!( coord @@ -748,23 +567,21 @@ mod tests { .unwrap(), (), ); - - // ineligible sum signature and ineligible update signature let sum_signature = Signature::from_slice_unchecked(&[ - 73, 255, 75, 96, 89, 197, 182, 203, 156, 41, 231, 88, 103, 16, 204, 35, 52, 165, 178, - 159, 33, 199, 112, 59, 203, 58, 243, 229, 190, 226, 168, 96, 146, 49, 79, 147, 224, - 235, 140, 247, 101, 99, 255, 179, 150, 219, 84, 69, 146, 49, 182, 105, 42, 65, 159, 41, - 118, 214, 172, 240, 213, 27, 192, 12, + 136, 94, 175, 83, 39, 171, 196, 102, 225, 111, 39, 28, 104, 51, 34, 117, 112, 178, 165, + 134, 128, 184, 131, 67, 73, 244, 98, 0, 133, 12, 111, 60, 215, 19, 237, 197, 96, 110, + 27, 196, 205, 3, 201, 112, 30, 24, 109, 145, 30, 62, 169, 130, 113, 35, 253, 194, 148, + 111, 151, 203, 238, 109, 223, 13, ]); let update_signature = Signature::from_slice_unchecked(&[ - 163, 180, 225, 224, 231, 2, 162, 183, 211, 242, 26, 56, 124, 179, 241, 13, 105, 29, - 240, 251, 89, 126, 147, 229, 138, 68, 118, 206, 102, 193, 209, 79, 219, 109, 87, 59, - 197, 177, 197, 213, 79, 143, 149, 66, 159, 107, 139, 244, 6, 224, 111, 175, 90, 213, - 206, 143, 152, 0, 21, 15, 102, 74, 15, 14, + 189, 170, 55, 119, 59, 71, 14, 211, 117, 167, 110, 79, 44, 160, 171, 199, 43, 77, 147, + 65, 121, 172, 77, 248, 81, 62, 66, 111, 235, 209, 131, 188, 5, 117, 123, 81, 204, 136, + 205, 213, 28, 248, 46, 39, 83, 80, 66, 3, 77, 224, 60, 248, 231, 216, 241, 224, 87, + 170, 120, 214, 43, 106, 188, 13, ]); let pk = PublicSigningKey::from_slice_unchecked(&[ - 109, 181, 253, 91, 247, 2, 201, 224, 161, 207, 128, 48, 16, 201, 86, 14, 193, 204, 49, - 88, 9, 170, 109, 120, 245, 0, 208, 129, 107, 213, 253, 72, + 221, 242, 188, 27, 163, 226, 152, 164, 43, 89, 154, 78, 26, 54, 35, 233, 129, 245, 131, + 251, 251, 154, 171, 121, 207, 58, 134, 201, 185, 31, 80, 181, ]); assert_eq!( coord @@ -772,23 +589,21 @@ mod tests { .unwrap_err(), PetError::InvalidMessage, ); - - // eligible sum signature and eligible update signature let sum_signature = Signature::from_slice_unchecked(&[ - 22, 28, 85, 58, 83, 51, 179, 43, 142, 58, 15, 113, 125, 191, 145, 179, 22, 216, 183, - 114, 230, 219, 151, 4, 213, 187, 197, 160, 171, 240, 40, 0, 133, 132, 7, 117, 105, 37, - 84, 214, 243, 19, 187, 132, 80, 194, 214, 204, 58, 130, 33, 63, 40, 149, 30, 27, 106, - 122, 254, 106, 161, 61, 176, 5, + 70, 46, 99, 192, 150, 169, 206, 133, 91, 206, 219, 205, 228, 255, 57, 96, 186, 64, 63, + 79, 109, 112, 192, 225, 238, 41, 5, 27, 213, 91, 83, 60, 219, 81, 227, 101, 30, 12, 36, + 87, 37, 57, 64, 184, 146, 129, 217, 215, 212, 43, 77, 255, 202, 93, 150, 25, 147, 50, + 63, 93, 8, 83, 33, 14, ]); let update_signature = Signature::from_slice_unchecked(&[ - 7, 50, 23, 176, 28, 214, 185, 141, 131, 236, 166, 140, 232, 21, 223, 88, 16, 98, 202, - 232, 46, 210, 102, 177, 107, 196, 87, 192, 36, 153, 175, 104, 208, 61, 179, 151, 191, - 103, 75, 70, 109, 185, 10, 215, 28, 29, 12, 68, 15, 124, 248, 159, 57, 84, 156, 83, - 189, 233, 8, 184, 197, 21, 51, 1, + 222, 204, 229, 157, 200, 187, 57, 66, 40, 158, 76, 184, 105, 1, 221, 122, 119, 110, + 115, 98, 119, 189, 130, 222, 8, 83, 69, 80, 107, 230, 18, 58, 180, 198, 160, 115, 111, + 173, 147, 182, 89, 197, 14, 138, 199, 64, 28, 34, 51, 98, 32, 219, 138, 252, 133, 139, + 219, 212, 207, 133, 61, 79, 200, 7, ]); let pk = PublicSigningKey::from_slice_unchecked(&[ - 212, 224, 51, 239, 70, 208, 166, 236, 81, 5, 7, 226, 54, 151, 50, 223, 133, 134, 66, - 167, 32, 226, 141, 200, 232, 41, 112, 144, 79, 135, 207, 87, + 63, 238, 181, 248, 155, 69, 222, 175, 198, 46, 148, 78, 39, 51, 249, 250, 45, 157, 92, + 1, 18, 43, 24, 199, 144, 235, 245, 85, 63, 225, 151, 120, ]); assert_eq!( coord @@ -796,23 +611,21 @@ mod tests { .unwrap_err(), PetError::InvalidMessage, ); - - // eligible sum signature and ineligible update signature let sum_signature = Signature::from_slice_unchecked(&[ - 176, 1, 85, 13, 43, 110, 122, 206, 186, 247, 44, 215, 154, 222, 34, 34, 173, 139, 166, - 42, 239, 160, 167, 126, 72, 234, 114, 1, 236, 10, 210, 155, 170, 33, 138, 129, 178, 56, - 154, 228, 84, 174, 187, 242, 3, 224, 143, 102, 134, 47, 49, 33, 103, 107, 147, 51, 36, - 143, 215, 134, 213, 162, 255, 5, + 186, 136, 94, 177, 248, 84, 83, 97, 83, 183, 242, 20, 93, 90, 21, 159, 238, 90, 82, + 254, 87, 74, 53, 23, 199, 27, 224, 156, 113, 252, 66, 90, 167, 109, 166, 89, 80, 96, + 216, 227, 177, 218, 216, 59, 239, 169, 132, 33, 91, 108, 26, 163, 159, 233, 34, 208, 7, + 19, 106, 175, 193, 253, 47, 14, ]); let update_signature = Signature::from_slice_unchecked(&[ - 39, 29, 201, 153, 218, 79, 161, 208, 151, 222, 220, 95, 118, 156, 17, 49, 35, 125, 243, - 214, 83, 240, 196, 168, 166, 225, 86, 103, 140, 237, 252, 196, 11, 5, 85, 18, 126, 210, - 82, 14, 88, 198, 114, 39, 239, 226, 243, 28, 48, 22, 39, 19, 244, 103, 13, 92, 216, - 251, 155, 154, 180, 114, 158, 13, + 146, 127, 108, 132, 170, 89, 77, 240, 50, 81, 109, 30, 120, 212, 65, 155, 132, 147, + 199, 86, 136, 204, 184, 14, 162, 107, 45, 215, 73, 129, 214, 79, 160, 249, 118, 47, + 116, 140, 91, 200, 226, 203, 166, 35, 54, 24, 148, 124, 113, 154, 131, 141, 122, 25, + 26, 224, 175, 60, 221, 27, 252, 234, 245, 15, ]); let pk = PublicSigningKey::from_slice_unchecked(&[ - 251, 251, 252, 131, 93, 84, 116, 191, 88, 135, 45, 43, 201, 66, 7, 236, 40, 74, 17, 11, - 33, 126, 224, 127, 77, 232, 59, 34, 120, 174, 137, 2, + 147, 43, 34, 245, 84, 183, 114, 36, 243, 153, 91, 4, 75, 52, 247, 250, 86, 96, 127, + 106, 222, 191, 119, 72, 208, 88, 242, 40, 178, 151, 8, 7, ]); assert_eq!( coord @@ -825,8 +638,8 @@ mod tests { fn auxiliary_sum(min_sum: usize) -> SumDict { iter::repeat_with(|| { ( - PublicSigningKey::from_slice_unchecked(&randombytes(32)), - PublicEncryptKey::from_slice_unchecked(&randombytes(32)), + PublicSigningKey::from_slice(&randombytes(32)).unwrap(), + PublicEncryptKey::from_slice(&randombytes(32)).unwrap(), ) }) .take(min_sum) @@ -835,7 +648,7 @@ mod tests { #[test] fn test_sum_dict() { - let mut coord = Coordinator::::new().unwrap(); + let mut coord = Coordinator::new().unwrap(); coord.min_sum = 3; coord.min_update = 3; coord.try_phase_transition(); // start the sum phase @@ -856,7 +669,7 @@ mod tests { let sum_dict = auxiliary_sum(coord.min_sum); for (pk, ephm_pk) in sum_dict.iter() { assert!(!coord.has_enough_sums()); - coord.add_sum_participant(pk, ephm_pk).unwrap(); + coord.sum_dict.insert(*pk, *ephm_pk); } assert_eq!(coord.sum_dict, sum_dict); assert!(coord.seed_dict.is_empty()); @@ -879,11 +692,11 @@ mod tests { } fn generate_update(sum_dict: &SumDict) -> (UpdateParticipantPublicKey, LocalSeedDict) { - let seed = MaskSeed::generate(); - let pk = PublicSigningKey::from_slice_unchecked(&randombytes(32)); + let seed = MaskSeed::new(); + let pk = PublicSigningKey::from_slice(&randombytes(32)).unwrap(); let local_seed_dict = sum_dict .iter() - .map(|(sum_pk, sum_ephm_pk)| (*sum_pk, seed.encrypt(sum_ephm_pk))) + .map(|(sum_pk, sum_ephm_pk)| (*sum_pk, seed.seal(sum_ephm_pk))) .collect::(); (pk, local_seed_dict) } @@ -916,7 +729,7 @@ mod tests { #[test] fn test_seed_dict() { - let mut coord = Coordinator::::new().unwrap(); + let mut coord = Coordinator::new().unwrap(); coord.min_sum = 3; coord.min_update = 3; coord.try_phase_transition(); // start the sum phase @@ -951,62 +764,47 @@ mod tests { } fn auxiliary_mask(min_sum: usize) -> (Vec, MaskDict) { - let config = MaskConfigs::from_parts( - GroupType::Prime, - DataType::F32, - BoundType::B0, - ModelType::M3, - ) - .config(); + // this doesn't work for `min_sum == 0` and `min_sum == 2` let masks = [ - vec![MaskSeed::generate().derive_mask(10, &config); min_sum - 1], - vec![MaskSeed::generate().derive_mask(10, &config); 1], + vec![Mask::from(randombytes(32)); min_sum - 1], + vec![Mask::from(randombytes(32)); 1], ] .concat(); - let mask_dict = [ - (masks[0].clone(), min_sum - 1), - (masks[min_sum - 1].clone(), 1), - ] - .iter() - .cloned() - .collect::(); + let mask_dict = masks + .iter() + .map(|mask| Sha256::hash(mask.as_ref())) + .collect::(); (masks, mask_dict) } #[test] fn test_mask_dict() { - let mut coord = Coordinator::::new().unwrap(); + let mut coord = Coordinator::new().unwrap(); coord.min_sum = 3; coord.min_update = 3; coord.phase = Phase::Sum2; // Pretend we received enough masks - let sum_dict = auxiliary_sum(coord.min_sum); - coord.sum_dict = sum_dict.clone(); let (masks, mask_dict) = auxiliary_mask(coord.min_sum); - for (pk, mask) in sum_dict.keys().zip(masks.iter()) { - coord.add_mask(pk, mask).unwrap(); - } + coord + .mask_dict + .update(masks.iter().map(|mask| Sha256::hash(mask.as_ref()))); assert_eq!(coord.mask_dict, mask_dict); assert!(coord.has_enough_masks()); - assert_eq!(coord.freeze_mask_dict().unwrap(), &masks[0]); + assert_eq!( + coord.freeze_mask_dict().unwrap(), + mask_dict.most_common()[0].0, + ); } #[test] fn test_mask_dict_fail() { - let mut coord = Coordinator::::new().unwrap(); + let mut coord = Coordinator::new().unwrap(); coord.min_sum = 3; coord.min_update = 3; coord.phase = Phase::Sum2; - let config = MaskConfigs::from_parts( - GroupType::Prime, - DataType::F32, - BoundType::B0, - ModelType::M3, - ) - .config(); - coord.mask_dict = iter::repeat_with(|| (MaskSeed::generate().derive_mask(10, &config), 1)) + coord.mask_dict = iter::repeat_with(|| Sha256::hash(&randombytes(32))) .take(coord.min_sum) .collect::(); assert_eq!( @@ -1017,7 +815,7 @@ mod tests { #[test] fn test_clear_round_dicts() { - let mut coord = Coordinator::::new().unwrap(); + let mut coord = Coordinator::new().unwrap(); coord.clear_round_dicts(); assert!(coord.sum_dict.is_empty()); assert!(coord.seed_dict.is_empty()); @@ -1026,7 +824,7 @@ mod tests { #[test] fn test_gen_round_keypair() { - let mut coord = Coordinator::::new().unwrap(); + let mut coord = Coordinator::new().unwrap(); coord.gen_round_keypair(); assert_eq!(coord.pk, coord.sk.public_key()); assert_eq!(coord.sk.as_slice().len(), 32); @@ -1034,11 +832,11 @@ mod tests { #[test] fn test_update_round_seed() { - let mut coord = Coordinator::::new().unwrap(); - coord.seed = RoundSeed::from_slice_unchecked(&[ + let mut coord = Coordinator::new().unwrap(); + coord.seed = vec![ 229, 16, 164, 40, 138, 161, 23, 161, 175, 102, 13, 103, 229, 229, 163, 56, 184, 250, 190, 44, 91, 69, 246, 222, 64, 101, 139, 22, 126, 6, 103, 238, - ]); + ]; coord.sk = SecretEncryptKey::from_slice_unchecked(&[ 39, 177, 238, 71, 112, 48, 60, 73, 246, 28, 143, 222, 211, 114, 29, 34, 174, 28, 77, 51, 146, 27, 155, 224, 20, 169, 254, 164, 231, 141, 190, 31, @@ -1046,16 +844,16 @@ mod tests { coord.update_round_seed(); assert_eq!( coord.seed, - RoundSeed::from_slice_unchecked(&[ + vec![ 90, 35, 97, 78, 70, 149, 40, 131, 149, 211, 30, 236, 194, 175, 156, 76, 85, 43, - 138, 159, 180, 166, 25, 205, 156, 176, 3, 203, 27, 128, 231, 38, - ]), + 138, 159, 180, 166, 25, 205, 156, 176, 3, 203, 27, 128, 231, 38 + ], ); } #[test] fn test_transitions() { - let mut coord = Coordinator::::new().unwrap(); + let mut coord = Coordinator::new().unwrap(); coord.min_sum = 3; coord.min_update = 3; @@ -1120,21 +918,13 @@ mod tests { // Pretend we received enough masks and transition. This time // the state should change and we should restart a round - let integers = vec![BigUint::zero(); 10]; - let config = MaskConfigs::from_parts( - GroupType::Prime, - DataType::F32, - BoundType::B0, - ModelType::M3, - ) - .config(); - coord.masked_model = Some(MaskedModel::from_parts(integers, config).unwrap()); + let chosen_seed = mask_dict.most_common().into_iter().next().unwrap().0; coord.mask_dict = mask_dict; let seed = coord.seed.clone(); coord.try_phase_transition(); assert_eq!( coord.next_event().unwrap(), - ProtocolEvent::EndRound(Some(())) + ProtocolEvent::EndRound(Some(chosen_seed)) ); assert_eq!( coord.next_event().unwrap(), diff --git a/rust/src/crypto/mod.rs b/rust/src/crypto/mod.rs index b12d0d8ab..69b427485 100644 --- a/rust/src/crypto/mod.rs +++ b/rust/src/crypto/mod.rs @@ -3,12 +3,8 @@ mod encrypt; mod hash; +mod prng; mod sign; - -use num::{bigint::BigUint, traits::identities::Zero}; -use rand::RngCore; -use rand_chacha::ChaCha20Rng; - pub use self::{ encrypt::{ generate_encrypt_key_pair, @@ -17,6 +13,8 @@ pub use self::{ SecretEncryptKey, SEALBYTES, }, + hash::Sha256, + prng::generate_integer, sign::{ generate_signing_key_pair, PublicSigningKey, @@ -48,72 +46,3 @@ pub trait ByteObject: Sized { Self::from_slice(bytes).unwrap() } } - -/// Generate a secure pseudo-random integer. Draws from a uniform distribution over the integers -/// between zero (included) and `max_int` (excluded). -pub fn generate_integer(prng: &mut ChaCha20Rng, max_int: &BigUint) -> BigUint { - if max_int.is_zero() { - return BigUint::zero(); - } - let mut bytes = max_int.to_bytes_le(); - let mut rand_int = max_int.clone(); - while rand_int >= *max_int { - prng.fill_bytes(&mut bytes); - rand_int = BigUint::from_bytes_le(&bytes); - } - rand_int -} - -#[cfg(test)] -mod tests { - use num::traits::{pow::Pow, Num}; - use rand::SeedableRng; - - use super::*; - - #[test] - fn test_generate_integer() { - let mut prng = ChaCha20Rng::from_seed([0_u8; 32]); - let max_int = BigUint::from(u128::max_value()).pow(2_usize); - assert_eq!( - generate_integer(&mut prng, &max_int), - BigUint::from_str_radix( - "90034050956742099321159087842304570510687605373623064829879336909608119744630", - 10 - ) - .unwrap() - ); - assert_eq!( - generate_integer(&mut prng, &max_int), - BigUint::from_str_radix( - "60790020689334235010238064028215988394112077193561636249125918224917556969946", - 10 - ) - .unwrap() - ); - assert_eq!( - generate_integer(&mut prng, &max_int), - BigUint::from_str_radix( - "107415344426328791036720294006773438815099086866510488084511304829720271980447", - 10 - ) - .unwrap() - ); - assert_eq!( - generate_integer(&mut prng, &max_int), - BigUint::from_str_radix( - "50343610553303623842889112417183549658912134525854625844144939347139411162921", - 10 - ) - .unwrap() - ); - assert_eq!( - generate_integer(&mut prng, &max_int), - BigUint::from_str_radix( - "42382469383990928111449714288937630103705168010724718767641573929365517895981", - 10 - ) - .unwrap() - ); - } -} diff --git a/rust/src/crypto/prng.rs b/rust/src/crypto/prng.rs new file mode 100644 index 000000000..6f68ac527 --- /dev/null +++ b/rust/src/crypto/prng.rs @@ -0,0 +1,73 @@ +use num::{bigint::BigUint, traits::identities::Zero}; +use rand::RngCore; +use rand_chacha::ChaCha20Rng; + +/// Generate a secure pseudo-random integer. Draws from a uniform +/// distribution over the integers between zero (included) and +/// `max_int` (excluded). +pub fn generate_integer(prng: &mut ChaCha20Rng, max_int: &BigUint) -> BigUint { + if max_int.is_zero() { + return BigUint::zero(); + } + let mut bytes = max_int.to_bytes_le(); + let mut rand_int = max_int.clone(); + while rand_int >= *max_int { + prng.fill_bytes(&mut bytes); + rand_int = BigUint::from_bytes_le(&bytes); + } + rand_int +} + +#[cfg(test)] +mod tests { + use num::traits::{pow::Pow, Num}; + use rand::SeedableRng; + + use super::*; + + #[test] + fn test_generate_integer() { + let mut prng = ChaCha20Rng::from_seed([0_u8; 32]); + let max_int = BigUint::from(u128::max_value()).pow(2_usize); + assert_eq!( + generate_integer(&mut prng, &max_int), + BigUint::from_str_radix( + "90034050956742099321159087842304570510687605373623064829879336909608119744630", + 10 + ) + .unwrap() + ); + assert_eq!( + generate_integer(&mut prng, &max_int), + BigUint::from_str_radix( + "60790020689334235010238064028215988394112077193561636249125918224917556969946", + 10 + ) + .unwrap() + ); + assert_eq!( + generate_integer(&mut prng, &max_int), + BigUint::from_str_radix( + "107415344426328791036720294006773438815099086866510488084511304829720271980447", + 10 + ) + .unwrap() + ); + assert_eq!( + generate_integer(&mut prng, &max_int), + BigUint::from_str_radix( + "50343610553303623842889112417183549658912134525854625844144939347139411162921", + 10 + ) + .unwrap() + ); + assert_eq!( + generate_integer(&mut prng, &max_int), + BigUint::from_str_radix( + "42382469383990928111449714288937630103705168010724718767641573929365517895981", + 10 + ) + .unwrap() + ); + } +} diff --git a/rust/src/crypto/sign.rs b/rust/src/crypto/sign.rs index 14b50c366..009aeab45 100644 --- a/rust/src/crypto/sign.rs +++ b/rust/src/crypto/sign.rs @@ -1,10 +1,6 @@ use super::ByteObject; use derive_more::{AsMut, AsRef, From}; -use num::{ - bigint::{BigUint, ToBigInt}, - rational::Ratio, -}; -use sodiumoxide::crypto::{hash::sha256, sign}; +use sodiumoxide::crypto::sign; /// Generate a new random signing key pair pub fn generate_signing_key_pair() -> (PublicSigningKey, SecretSigningKey) { @@ -128,27 +124,6 @@ impl ByteObject for Signature { } } -impl Signature { - /// Compute the floating point representation of the hashed signature and ensure that it - /// is below the given threshold: int(hash(signature)) / (2**hashbits - 1) <= threshold. - pub fn is_eligible(&self, threshold: f64) -> bool { - if threshold < 0_f64 { - return false; - } else if threshold > 1_f64 { - return true; - } - // safe unwraps: `to_bigint` never fails for `BigUint`s - let numer = BigUint::from_bytes_le(sha256::hash(self.as_slice()).as_ref()) - .to_bigint() - .unwrap(); - let denom = BigUint::from_bytes_le([u8::MAX; sha256::DIGESTBYTES].as_ref()) - .to_bigint() - .unwrap(); - // safe unwrap: `threshold` is guaranteed to be finite - Ratio::new(numer, denom) <= Ratio::from_float(threshold).unwrap() - } -} - /// A seed that can be used for signing key pair generation. When /// `KeySeed` goes out of scope, its contents will be zeroed out. #[derive(AsRef, AsMut, From, Serialize, Deserialize, Eq, PartialEq, Clone)] @@ -178,29 +153,3 @@ impl ByteObject for SigningKeySeed { self.0.as_ref() } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_signature_is_eligible() { - // eligible signature - let sig = Signature::from_slice_unchecked(&[ - 172, 29, 85, 219, 118, 44, 107, 32, 219, 253, 25, 242, 53, 45, 111, 62, 102, 130, 24, - 8, 222, 199, 34, 120, 166, 163, 223, 229, 100, 50, 252, 244, 250, 88, 196, 151, 136, - 48, 39, 198, 166, 86, 29, 151, 13, 81, 69, 198, 40, 148, 134, 126, 7, 202, 1, 56, 174, - 43, 89, 28, 242, 194, 4, 214, - ]); - assert!(sig.is_eligible(0.5_f64)); - - // ineligible signature - let sig = Signature::from_slice_unchecked(&[ - 119, 2, 197, 174, 52, 165, 229, 22, 218, 210, 240, 188, 220, 232, 149, 129, 211, 13, - 61, 217, 186, 79, 102, 15, 109, 237, 83, 193, 12, 117, 210, 66, 99, 230, 30, 131, 63, - 108, 28, 222, 48, 92, 153, 71, 159, 220, 115, 181, 183, 155, 146, 182, 205, 89, 140, - 234, 100, 40, 199, 248, 23, 147, 172, 248, - ]); - assert!(!sig.is_eligible(0.5_f64)); - } -} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 954fc7591..c860810a7 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,46 +1,33 @@ #![allow(dead_code)] -#![allow(unused_imports)] -#![feature(or_patterns)] -#![feature(const_fn)] -#![feature(stmt_expr_attributes)] -#![feature(option_unwrap_none)] +#![feature(bool_to_option)] #[macro_use] -extern crate tracing; +extern crate serde; + #[macro_use] extern crate tracing; -#[macro_use] -mod macros; pub mod certificate; pub mod coordinator; -pub mod crypto; pub mod mask; pub mod message; -pub mod model; pub mod participant; pub mod service; pub mod utils; use std::collections::HashMap; -use crypto::{PublicEncryptKey, PublicSigningKey, SecretEncryptKey, SecretSigningKey, Signature}; -use thiserror::Error; - -use crate::mask::seed::EncryptedMaskSeed; - -#[derive(Error, Debug)] -#[error("initialization failed: insufficient system entropy to generate secrets")] -pub struct InitError; - -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq)] /// PET protocol errors. pub enum PetError { InvalidMessage, - InvalidModel, InvalidMask, + InvalidModel, } +pub mod crypto; +use crypto::{PublicEncryptKey, PublicSigningKey, SecretEncryptKey, SecretSigningKey, Signature}; + /// A public encryption key that identifies a coordinator. pub type CoordinatorPublicKey = PublicEncryptKey; @@ -86,13 +73,16 @@ pub type SumDict = HashMap; +pub type LocalSeedDict = HashMap; /// A dictionary created during the update phase of the protocol. The global seed dictionary is /// built from the local seed dictionaries sent by the update participants. It maps each sum /// participant to the encrypted masking seeds of all the update participants. -type SeedDict = - HashMap>; +pub type SeedDict = + HashMap>; -/// A 32-byte hash that identifies a model mask computed by a sum participant. -pub type MaskHash = sodiumoxide::crypto::hash::sha256::Digest; +use thiserror::Error; + +#[derive(Error, Debug)] +#[error("initialization failed: insufficient system entropy to generate secrets")] +pub struct InitError; diff --git a/rust/src/macros.rs b/rust/src/macros.rs deleted file mode 100644 index 63fcbd32b..000000000 --- a/rust/src/macros.rs +++ /dev/null @@ -1,93 +0,0 @@ -#[macro_export] -/// Define field accessor methods for a trait to be implemented on a corresponding structure. -/// -/// # Example -/// -/// Writing `define_trait_fields!(bytes, Vec);` will generate the following trait method -/// signatures: -/// ```text -/// /// Get a reference to the bytes field. -/// fn bytes(&self) -> &Vec; -/// -/// /// Get a mutable reference to the bytes field. -/// fn bytes_mut(&mut self) -> &mut Vec; -/// ``` -/// The argument-tuples can be repeated by delimiting them with a semicolon. -macro_rules! define_trait_fields { - ($($name:ident, $type:ty);+ $(;)?) => { - paste::item! { - $( - /// Get a reference to the $name field. - fn $name(&self) -> &$type; - - /// Get a mutable reference to the $name field. - fn [<$name _mut>](&mut self) -> &mut $type; - - )+ - } - }; -} - -#[macro_export] -/// Derive field accessor methods for a trait implemented on a corresponding structure. -/// -/// # Example -/// -/// Writing `derive_trait_fields!(bytes, Vec);` will generate the following trait method for a -/// corresponding structure containing the field `bytes: Vec`: -/// ```text -/// /// Get a reference to the bytes field. -/// fn bytes(&self) -> &Vec { -/// &self.bytes -/// } -/// -/// /// Get a mutable reference to the bytes field. -/// fn bytes_mut(&mut self) -> &mut Vec { -/// &mut self.bytes -/// } -/// ``` -/// The argument-tuples can be repeated by delimiting them with a semicolon. -macro_rules! derive_trait_fields { - ($($name:ident, $type:ty);+ $(;)?) => { - paste::item! { - $( - /// Get a reference to the $name field. - fn $name(&self) -> &$type { - &self.$name - } - - /// Get a mutable reference to the $name field. - fn [<$name _mut>](&mut self) -> &mut $type { - &mut self.$name - } - )+ - - } - }; -} - -#[macro_export] -/// Derive field accessor methods for a structure. -/// -/// # Example -/// -/// Writing `derive_struct_fields!(bytes, Vec);` will generate the following struct method for a -/// corresponding structure containing the field `bytes: Vec`: -/// ```text -/// /// Get a reference to the bytes field. -/// pub fn bytes(&self) -> &Vec { -/// &self.bytes -/// } -/// ``` -/// The argument-tuples can be repeated by delimiting them with a semicolon. -macro_rules! derive_struct_fields { - ($($name:ident, $type:ty);+ $(;)?) => { - $( - /// Get a reference to the $name field. - pub fn $name(&self) -> &$type { - &self.$name - } - )+ - - }; -} diff --git a/rust/src/mask/config.rs b/rust/src/mask/config.rs deleted file mode 100644 index 4bc1a2e88..000000000 --- a/rust/src/mask/config.rs +++ /dev/null @@ -1,648 +0,0 @@ -use std::convert::{TryFrom, TryInto}; - -use num::{ - bigint::{BigInt, BigUint}, - rational::Ratio, - traits::{identities::Zero, pow::Pow, Num}, -}; - -use crate::PetError; - -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -/// A mask configuration. -pub struct MaskConfig { - name: MaskConfigs, - add_shift: Ratio, - exp_shift: BigInt, - order: BigUint, -} - -impl MaskConfig { - derive_struct_fields!( - name, MaskConfigs; - add_shift, Ratio; - exp_shift, BigInt; - order, BigUint; - ); - - /// Get the number of bytes needed to represent the largest element of the finite group. - pub fn element_len(&self) -> usize { - if self.order.is_zero() { - 1 - } else { - (self.order() - BigUint::from(1_usize)).to_bytes_le().len() - } - } - - /// Serialize the mask configuration into bytes. - pub fn serialize(&self) -> Vec { - [ - (self.name.group_type as u8).to_le_bytes(), - (self.name.data_type as u8).to_le_bytes(), - (self.name.bound_type as u8).to_le_bytes(), - (self.name.model_type as u8).to_le_bytes(), - ] - .concat() - } - - /// Deserialize the mask configuration from bytes. Fails if any of its parts is invalid. - pub fn deserialize(bytes: &[u8]) -> Result { - if bytes.len() == 4 { - Ok(MaskConfigs { - group_type: bytes[0].try_into()?, - data_type: bytes[1].try_into()?, - bound_type: bytes[2].try_into()?, - model_type: bytes[3].try_into()?, - } - .config()) - } else { - Err(PetError::InvalidMask) - } - } -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] -#[repr(u8)] -/// The order of the finite group. -pub enum GroupType { - /// A finite group of exact integer order. - Integer, - /// A finite group of prime order. - Prime, - /// A finite group of power-of-two order. - Power2, -} - -impl TryFrom for GroupType { - type Error = PetError; - - /// Get the group type. Fails if the encoding is unknown. - fn try_from(byte: u8) -> Result { - match byte { - 0 => Ok(Self::Integer), - 1 => Ok(Self::Prime), - 2 => Ok(Self::Power2), - _ => Err(Self::Error::InvalidMask), - } - } -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] -#[repr(u8)] -/// The data type of the numbers to be masked. -pub enum DataType { - /// Numbers of type f32. - F32, - /// Numbers of type f64. - F64, - /// Numbers of type i32. - I32, - /// Numbers of type i64. - I64, -} - -impl TryFrom for DataType { - type Error = PetError; - - /// Get the data type. Fails if the encoding is unknown. - fn try_from(byte: u8) -> Result { - match byte { - 0 => Ok(Self::F32), - 1 => Ok(Self::F64), - 2 => Ok(Self::I32), - 3 => Ok(Self::I64), - _ => Err(Self::Error::InvalidMask), - } - } -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] -#[repr(u8)] -/// The bounds of the numbers to be masked. -pub enum BoundType { - /// Numbers absolutely bounded by 1. - B0 = 0, - /// Numbers absolutely bounded by 100. - B2 = 2, - /// Numbers absolutely bounded by 10_000. - B4 = 4, - /// Numbers absolutely bounded by 1_000_000. - B6 = 6, - /// Numbers absolutely bounded by their data types' maximum absolute value. - Bmax = 255, -} - -impl TryFrom for BoundType { - type Error = PetError; - - /// Get the bound type. Fails if the encoding is unknown. - fn try_from(byte: u8) -> Result { - match byte { - 0 => Ok(Self::B0), - 2 => Ok(Self::B2), - 4 => Ok(Self::B4), - 6 => Ok(Self::B6), - 255 => Ok(Self::Bmax), - _ => Err(Self::Error::InvalidMask), - } - } -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] -#[repr(u8)] -/// The number of models to be aggregated at most. -pub enum ModelType { - /// At most 1_000 models to be aggregated. - M3 = 3, - /// At most 1_000_000 models to be aggregated. - M6 = 6, - /// At most 1_000_000_000 models to be aggregated. - M9 = 9, - /// At most 1_000_000_000_000 models to be aggregated. - M12 = 12, -} - -impl TryFrom for ModelType { - type Error = PetError; - - /// Get the model type. Fails if the encoding is unknown. - fn try_from(byte: u8) -> Result { - match byte { - 3 => Ok(Self::M3), - 6 => Ok(Self::M6), - 9 => Ok(Self::M9), - 12 => Ok(Self::M12), - _ => Err(Self::Error::InvalidMask), - } - } -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] -/// A mask configuration name. Consists of identifiers for its parts: -/// - the order of the finite group -/// - the data type of the numbers to be masked -/// - the bounds of the numbers to be masked -/// - the number of models to be aggregated at most -pub struct MaskConfigs { - group_type: GroupType, - data_type: DataType, - bound_type: BoundType, - model_type: ModelType, -} - -impl MaskConfigs { - derive_struct_fields!( - group_type, GroupType; - data_type, DataType; - bound_type, BoundType; - model_type, ModelType; - ); - - /// Create a mask configuration name from its parts. - pub fn from_parts( - group_type: GroupType, - data_type: DataType, - bound_type: BoundType, - model_type: ModelType, - ) -> Self { - MaskConfigs { - group_type, - data_type, - bound_type, - model_type, - } - } - - /// Get the mask configuration corresponding to the name. - pub fn config(&self) -> MaskConfig { - use BoundType::{Bmax, B0, B2, B4, B6}; - use DataType::{F32, F64, I32, I64}; - use GroupType::{Integer, Power2, Prime}; - use ModelType::{M12, M3, M6, M9}; - - let name = *self; - let add_shift = match self.bound_type { - B0 => Ratio::from_integer(BigInt::from(1)), - B2 => Ratio::from_integer(BigInt::from(100)), - B4 => Ratio::from_integer(BigInt::from(10_000)), - B6 => Ratio::from_integer(BigInt::from(1_000_000)), - Bmax => match self.data_type { - // safe unwraps: all numbers are finite - F32 => Ratio::from_float(f32::MAX).unwrap(), - F64 => Ratio::from_float(f64::MAX).unwrap(), - I32 => Ratio::from_integer(-BigInt::from(i32::MIN)), - I64 => Ratio::from_integer(-BigInt::from(i64::MIN)), - }, - }; - let exp_shift = match self.data_type { - F32 => match self.bound_type { - B0 | B2 | B4 | B6 => BigInt::from(10).pow(10_u8), - Bmax => BigInt::from(10).pow(45_u8), - }, - F64 => match self.bound_type { - B0 | B2 | B4 | B6 => BigInt::from(10).pow(20_u8), - Bmax => BigInt::from(10).pow(324_u16), - }, - I32 | I64 => BigInt::from(1), - }; - let order = match self.group_type { - Integer => match self.data_type { - F32 => match self.bound_type { - B0 => match self.model_type { - // safe unwraps: radix and all strings are valid - M3 => BigUint::from_str_radix("20_000_000_000_000", 10).unwrap(), - M6 => BigUint::from_str_radix("20_000_000_000_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("20_000_000_000_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("20_000_000_000_000_000_000_000", 10).unwrap(), - } - B2 => match self.model_type { - M3 => BigUint::from_str_radix("2_000_000_000_000_000", 10).unwrap(), - M6 => BigUint::from_str_radix("2_000_000_000_000_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("2_000_000_000_000_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("2_000_000_000_000_000_000_000_000", 10).unwrap(), - } - B4 => match self.model_type { - M3 => BigUint::from_str_radix("200_000_000_000_000_000", 10).unwrap(), - M6 => BigUint::from_str_radix("200_000_000_000_000_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("200_000_000_000_000_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("200_000_000_000_000_000_000_000_000", 10).unwrap(), - } - B6 => match self.model_type { - M3 => BigUint::from_str_radix("20_000_000_000_000_000_000", 10).unwrap(), - M6 => BigUint::from_str_radix("20_000_000_000_000_000_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("20_000_000_000_000_000_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("20_000_000_000_000_000_000_000_000_000", 10).unwrap(), - } - Bmax => match self.model_type { - M3 => BigUint::from_str_radix("680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - M6 => BigUint::from_str_radix("680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - } - } - F64 => match self.bound_type { - B0 => match self.model_type { - M3 => BigUint::from_str_radix("200_000_000_000_000_000_000_000", 10).unwrap(), - M6 => BigUint::from_str_radix("200_000_000_000_000_000_000_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("200_000_000_000_000_000_000_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("200_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - } - B2 => match self.model_type { - M3 => BigUint::from_str_radix("20_000_000_000_000_000_000_000_000", 10).unwrap(), - M6 => BigUint::from_str_radix("20_000_000_000_000_000_000_000_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("20_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("20_000_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - } - B4 => match self.model_type { - M3 => BigUint::from_str_radix("2_000_000_000_000_000_000_000_000_000", 10).unwrap(), - M6 => BigUint::from_str_radix("2_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("2_000_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("2_000_000_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - } - B6 => match self.model_type { - M3 => BigUint::from_str_radix("200_000_000_000_000_000_000_000_000_000", 10).unwrap(), - M6 => BigUint::from_str_radix("200_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("200_000_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("200_000_000_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - } - Bmax => match self.model_type { - M3 => BigUint::from_str_radix("359_538_626_972_463_100_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - M6 => BigUint::from_str_radix("359_538_626_972_463_100_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("359_538_626_972_463_100_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("359_538_626_972_463_100_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000", 10).unwrap(), - } - } - I32 => match self.bound_type { - B0 => match self.model_type { - M3 => BigUint::from_str_radix("2_000", 10).unwrap(), - M6 => BigUint::from_str_radix("2_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("2_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("2_000_000_000_000", 10).unwrap(), - } - B2 => match self.model_type { - M3 => BigUint::from_str_radix("200_000", 10).unwrap(), - M6 => BigUint::from_str_radix("200_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("200_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("200_000_000_000_000", 10).unwrap(), - } - B4 => match self.model_type { - M3 => BigUint::from_str_radix("20_000_000", 10).unwrap(), - M6 => BigUint::from_str_radix("20_000_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("20_000_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("20_000_000_000_000_000", 10).unwrap(), - } - B6 => match self.model_type { - M3 => BigUint::from_str_radix("2_000_000_000", 10).unwrap(), - M6 => BigUint::from_str_radix("2_000_000_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("2_000_000_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("2_000_000_000_000_000_000", 10).unwrap(), - } - Bmax => match self.model_type { - M3 => BigUint::from_str_radix("4_294_967_295_000", 10).unwrap(), - M6 => BigUint::from_str_radix("4_294_967_295_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("4_294_967_295_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("4_294_967_295_000_000_000_000", 10).unwrap(), - } - } - I64 => match self.bound_type { - B0 => match self.model_type { - M3 => BigUint::from_str_radix("2_000", 10).unwrap(), - M6 => BigUint::from_str_radix("2_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("2_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("2_000_000_000_000", 10).unwrap(), - } - B2 => match self.model_type { - M3 => BigUint::from_str_radix("200_000", 10).unwrap(), - M6 => BigUint::from_str_radix("200_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("200_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("200_000_000_000_000", 10).unwrap(), - } - B4 => match self.model_type { - M3 => BigUint::from_str_radix("20_000_000", 10).unwrap(), - M6 => BigUint::from_str_radix("20_000_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("20_000_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("20_000_000_000_000_000", 10).unwrap(), - } - B6 => match self.model_type { - M3 => BigUint::from_str_radix("2_000_000_000", 10).unwrap(), - M6 => BigUint::from_str_radix("2_000_000_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("2_000_000_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("2_000_000_000_000_000_000", 10).unwrap(), - } - Bmax => match self.model_type { - M3 => BigUint::from_str_radix("18_446_744_073_709_551_615_000", 10).unwrap(), - M6 => BigUint::from_str_radix("18_446_744_073_709_551_615_000_000", 10).unwrap(), - M9 => BigUint::from_str_radix("18_446_744_073_709_551_615_000_000_000", 10).unwrap(), - M12 => BigUint::from_str_radix("18_446_744_073_709_551_615_000_000_000_000", 10).unwrap(), - } - } - } - Prime => match self.data_type { - F32 => match self.bound_type { - B0 => match self.model_type { - M3 => BigUint::from_str_radix("20_000_000_000_021", 10).unwrap(), - M6 => BigUint::from_str_radix("20_000_000_000_000_003", 10).unwrap(), - M9 => BigUint::from_str_radix("20_000_000_000_000_000_011", 10).unwrap(), - M12 => BigUint::from_str_radix("20_000_000_000_000_000_000_003", 10).unwrap(), - } - B2 => match self.model_type { - M3 => BigUint::from_str_radix("2_000_000_000_000_021", 10).unwrap(), - M6 => BigUint::from_str_radix("2_000_000_000_000_000_057", 10).unwrap(), - M9 => BigUint::from_str_radix("2_000_000_000_000_000_000_069", 10).unwrap(), - M12 => BigUint::from_str_radix("2_000_000_000_000_000_000_000_003", 10).unwrap(), - } - B4 => match self.model_type { - M3 => BigUint::from_str_radix("200_000_000_000_000_003", 10).unwrap(), - M6 => BigUint::from_str_radix("200_000_000_000_000_000_089", 10).unwrap(), - M9 => BigUint::from_str_radix("200_000_000_000_000_000_000_069", 10).unwrap(), - M12 => BigUint::from_str_radix("200_000_000_000_000_000_000_000_027", 10).unwrap(), - } - B6 => match self.model_type { - M3 => BigUint::from_str_radix("20_000_000_000_000_000_011", 10).unwrap(), - M6 => BigUint::from_str_radix("20_000_000_000_000_000_000_003", 10).unwrap(), - M9 => BigUint::from_str_radix("20_000_000_000_000_000_000_000_009", 10).unwrap(), - M12 => BigUint::from_str_radix("20_000_000_000_000_000_000_000_000_131", 10).unwrap(), - } - Bmax => match self.model_type { - M3 => BigUint::from_str_radix("680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_281", 10).unwrap(), - M6 => BigUint::from_str_radix("680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_323", 10).unwrap(), - M9 => BigUint::from_str_radix("680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_191", 10).unwrap(), - M12 => BigUint::from_str_radix("680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_083", 10).unwrap(), - } - } - F64 => match self.bound_type { - B0 => match self.model_type { - M3 => BigUint::from_str_radix("200_000_000_000_000_000_000_069", 10).unwrap(), - M6 => BigUint::from_str_radix("200_000_000_000_000_000_000_000_027", 10).unwrap(), - M9 => BigUint::from_str_radix("200_000_000_000_000_000_000_000_000_017", 10).unwrap(), - M12 => BigUint::from_str_radix("200_000_000_000_000_000_000_000_000_000_159", 10).unwrap(), - } - B2 => match self.model_type { - M3 => BigUint::from_str_radix("20_000_000_000_000_000_000_000_009", 10).unwrap(), - M6 => BigUint::from_str_radix("20_000_000_000_000_000_000_000_000_131", 10).unwrap(), - M9 => BigUint::from_str_radix("20_000_000_000_000_000_000_000_000_000_047", 10).unwrap(), - M12 => BigUint::from_str_radix("20_000_000_000_000_000_000_000_000_000_000_203", 10).unwrap(), - } - B4 => match self.model_type { - M3 => BigUint::from_str_radix("2_000_000_000_000_000_000_000_000_039", 10).unwrap(), - M6 => BigUint::from_str_radix("2_000_000_000_000_000_000_000_000_000_071", 10).unwrap(), - M9 => BigUint::from_str_radix("2_000_000_000_000_000_000_000_000_000_000_017", 10).unwrap(), - M12 => BigUint::from_str_radix("2_000_000_000_000_000_000_000_000_000_000_000_041", 10).unwrap(), - } - B6 => match self.model_type { - M3 => BigUint::from_str_radix("200_000_000_000_000_000_000_000_000_017", 10).unwrap(), - M6 => BigUint::from_str_radix("200_000_000_000_000_000_000_000_000_000_159", 10).unwrap(), - M9 => BigUint::from_str_radix("200_000_000_000_000_000_000_000_000_000_000_003", 10).unwrap(), - M12 => BigUint::from_str_radix("200_000_000_000_000_000_000_000_000_000_000_000_023", 10).unwrap(), - } - Bmax => match self.model_type { - M3 => BigUint::from_str_radix("359_538_626_972_463_140_000_000_000_000_000_000_000_593_874_019_667_231_666_067_439_096_529_924_969_333_439_983_391_110_599_943_465_644_007_133_099_721_551_828_263_813_044_710_323_667_390_405_279_670_626_898_022_875_314_671_948_577_301_533_414_396_469_719_048_504_306_012_596_386_638_859_340_084_030_210_314_832_025_518_258_115_226_051_894_034_477_843_584_650_149_420_090_374_373_134_876_775_786_923_748_346_298_936_467_612_015_276_401_624_887_654_050_299_443_392_510_555_689_981_501_608_709_494_004_423_956_258_647_440_955_320_257_123_787_935_493_476_104_132_776_728_548_437_783_283_112_428_445_450_269_488_453_346_610_914_359_272_368_862_786_051_728_965_455_746_393_095_846_720_860_347_644_662_201_994_241_194_193_316_457_656_284_847_050_135_299_403_149_697_261_199_957_835_824_000_531_233_031_619_352_921_347_101_423_914_861_961_738_035_659_301", 10).unwrap(), - M6 => BigUint::from_str_radix("359_538_626_972_463_139_999_999_999_999_999_999_999_903_622_106_309_601_840_402_558_296_261_360_055_843_460_163_714_984_640_183_652_353_129_826_112_739_444_431_322_400_938_984_152_600_575_421_591_212_739_537_896_016_542_591_595_727_264_024_538_428_559_469_178_136_611_680_881_710_150_818_089_794_351_154_869_285_409_959_876_691_068_635_451_827_253_162_844_058_791_343_487_286_852_635_234_799_336_668_682_655_217_329_655_102_622_197_942_194_212_857_658_834_043_465_713_831_143_523_811_067_060_369_640_438_677_832_007_091_511_212_788_398_470_391_285_320_720_769_417_737_628_120_102_221_909_739_846_753_580_817_462_645_602_854_496_103_866_327_474_145_187_363_329_320_852_679_912_679_009_543_036_760_757_409_720_574_191_338_832_841_104_183_169_976_025_577_743_061_881_721_861_634_977_765_641_182_996_194_573_448_626_763_720_938_201_976_656_541_039_724_303", 10).unwrap(), - M9 => BigUint::from_str_radix("359_538_626_972_463_139_999_999_999_999_999_999_999_904_930_781_891_526_077_660_862_016_966_437_766_478_934_820_885_791_914_528_679_207_262_530_042_483_798_832_910_003_057_874_958_310_694_484_517_139_841_166_977_272_287_522_418_122_134_527_125_053_808_273_636_647_181_903_383_717_418_169_782_215_585_647_900_802_728_035_567_327_931_187_710_919_458_230_957_036_511_507_150_288_137_858_111_024_099_126_399_746_768_695_036_546_643_813_753_385_062_385_762_652_380_150_346_615_796_407_577_297_605_069_883_839_431_646_689_072_072_214_687_584_099_356_273_959_025_519_093_953_786_032_481_175_596_842_406_101_871_239_892_163_505_527_137_519_569_046_747_947_203_065_300_865_116_331_411_924_515_285_552_096_042_635_874_474_960_733_445_241_451_746_509_870_642_272_026_256_695_499_704_624_475_309_137_281_644_358_183_373_160_068_523_639_023_207_643_484_888_657_559_597", 10).unwrap(), - M12 => BigUint::from_str_radix("359_538_626_972_463_139_999_999_999_999_999_999_999_904_931_540_467_867_407_238_817_633_447_114_203_759_664_620_787_471_913_925_990_313_859_370_016_783_101_785_327_523_046_787_247_090_978_931_042_236_128_228_564_142_680_745_383_377_953_776_024_143_512_065_781_667_978_525_748_300_241_659_425_164_472_387_573_470_260_831_720_974_578_793_447_369_507_661_739_490_218_806_790_001_765_109_117_055_431_552_295_585_457_639_803_896_262_637_528_011_897_242_316_426_079_400_392_728_240_523_639_775_219_294_589_603_009_325_941_759_217_573_340_626_063_716_838_671_315_192_395_974_939_441_284_468_885_927_433_422_082_497_928_190_254_190_935_717_337_452_741_850_223_510_814_859_331_413_287_559_285_438_144_477_756_395_583_878_761_313_295_130_567_342_888_620_541_025_745_968_373_350_261_259_032_809_052_052_475_301_496_416_128_372_300_050_762_773_363_722_300_553_930_211_649", 10).unwrap(), - } - } - I32 => match self.bound_type { - B0 => match self.model_type { - M3 => BigUint::from_str_radix("2_003", 10).unwrap(), - M6 => BigUint::from_str_radix("2_000_003", 10).unwrap(), - M9 => BigUint::from_str_radix("2_000_000_011", 10).unwrap(), - M12 => BigUint::from_str_radix("2_000_000_000_003", 10).unwrap(), - } - B2 => match self.model_type { - M3 => BigUint::from_str_radix("200_003", 10).unwrap(), - M6 => BigUint::from_str_radix("200_000_033", 10).unwrap(), - M9 => BigUint::from_str_radix("200_000_000_041", 10).unwrap(), - M12 => BigUint::from_str_radix("200_000_000_000_027", 10).unwrap(), - } - B4 => match self.model_type { - M3 => BigUint::from_str_radix("20_000_003", 10).unwrap(), - M6 => BigUint::from_str_radix("20_000_000_089", 10).unwrap(), - M9 => BigUint::from_str_radix("20_000_000_000_021", 10).unwrap(), - M12 => BigUint::from_str_radix("20_000_000_000_000_003", 10).unwrap(), - } - B6 => match self.model_type { - M3 => BigUint::from_str_radix("2_000_000_011", 10).unwrap(), - M6 => BigUint::from_str_radix("2_000_000_000_003", 10).unwrap(), - M9 => BigUint::from_str_radix("2_000_000_000_000_021", 10).unwrap(), - M12 => BigUint::from_str_radix("2_000_000_000_000_000_057", 10).unwrap(), - } - Bmax => match self.model_type { - M3 => BigUint::from_str_radix("4_294_967_295_061", 10).unwrap(), - M6 => BigUint::from_str_radix("4_294_967_295_000_079", 10).unwrap(), - M9 => BigUint::from_str_radix("4_294_967_295_000_000_023", 10).unwrap(), - M12 => BigUint::from_str_radix("4_294_967_295_000_000_000_001", 10).unwrap(), - } - } - I64 => match self.bound_type { - B0 => match self.model_type { - M3 => BigUint::from_str_radix("2_003", 10).unwrap(), - M6 => BigUint::from_str_radix("2_000_003", 10).unwrap(), - M9 => BigUint::from_str_radix("2_000_000_011", 10).unwrap(), - M12 => BigUint::from_str_radix("2_000_000_000_003", 10).unwrap(), - } - B2 => match self.model_type { - M3 => BigUint::from_str_radix("200_003", 10).unwrap(), - M6 => BigUint::from_str_radix("200_000_033", 10).unwrap(), - M9 => BigUint::from_str_radix("200_000_000_041", 10).unwrap(), - M12 => BigUint::from_str_radix("200_000_000_000_027", 10).unwrap(), - } - B4 => match self.model_type { - M3 => BigUint::from_str_radix("20_000_003", 10).unwrap(), - M6 => BigUint::from_str_radix("20_000_000_089", 10).unwrap(), - M9 => BigUint::from_str_radix("20_000_000_000_021", 10).unwrap(), - M12 => BigUint::from_str_radix("20_000_000_000_000_003", 10).unwrap(), - } - B6 => match self.model_type { - M3 => BigUint::from_str_radix("2_000_000_011", 10).unwrap(), - M6 => BigUint::from_str_radix("2_000_000_000_003", 10).unwrap(), - M9 => BigUint::from_str_radix("2_000_000_000_000_021", 10).unwrap(), - M12 => BigUint::from_str_radix("2_000_000_000_000_000_057", 10).unwrap(), - } - Bmax => match self.model_type { - M3 => BigUint::from_str_radix("18_446_744_073_709_551_615_139", 10).unwrap(), - M6 => BigUint::from_str_radix("18_446_744_073_709_551_615_000_053", 10).unwrap(), - M9 => BigUint::from_str_radix("18_446_744_073_709_551_615_000_000_133", 10).unwrap(), - M12 => BigUint::from_str_radix("18_446_744_073_709_551_615_000_000_000_199", 10).unwrap(), - } - } - }, - Power2 => match self.data_type { - F32 => match self.bound_type { - B0 => match self.model_type { - M3 => BigUint::from_str_radix("35_184_372_088_832", 10).unwrap(), - M6 => BigUint::from_str_radix("36_028_797_018_963_968", 10).unwrap(), - M9 => BigUint::from_str_radix("36_893_488_147_419_103_232", 10).unwrap(), - M12 => BigUint::from_str_radix("37_778_931_862_957_161_709_568", 10).unwrap(), - } - B2 => match self.model_type { - M3 => BigUint::from_str_radix("2_251_799_813_685_248", 10).unwrap(), - M6 => BigUint::from_str_radix("2_305_843_009_213_693_952", 10).unwrap(), - M9 => BigUint::from_str_radix("2_361_183_241_434_822_606_848", 10).unwrap(), - M12 => BigUint::from_str_radix("2_417_851_639_229_258_349_412_352", 10).unwrap(), - } - B4 => match self.model_type { - M3 => BigUint::from_str_radix("288_230_376_151_711_744", 10).unwrap(), - M6 => BigUint::from_str_radix("295_147_905_179_352_825_856", 10).unwrap(), - M9 => BigUint::from_str_radix("302_231_454_903_657_293_676_544", 10).unwrap(), - M12 => BigUint::from_str_radix("309_485_009_821_345_068_724_781_056", 10).unwrap(), - } - B6 => match self.model_type { - M3 => BigUint::from_str_radix("36_893_488_147_419_103_232", 10).unwrap(), - M6 => BigUint::from_str_radix("37_778_931_862_957_161_709_568", 10).unwrap(), - M9 => BigUint::from_str_radix("38_685_626_227_668_133_590_597_632", 10).unwrap(), - M12 => BigUint::from_str_radix("39_614_081_257_132_168_796_771_975_168", 10).unwrap(), - } - Bmax => match self.model_type { - M3 => BigUint::from_str_radix("994_646_472_819_573_284_310_764_496_293_641_680_200_912_301_594_695_434_880_927_953_786_318_994_025_066_751_066_112", 10).unwrap(), - M6 => BigUint::from_str_radix("1_018_517_988_167_243_043_134_222_844_204_689_080_525_734_196_832_968_125_318_070_224_677_190_649_881_668_353_091_698_688", 10).unwrap(), - M9 => BigUint::from_str_radix("1_042_962_419_883_256_876_169_444_192_465_601_618_458_351_817_556_959_360_325_703_910_069_443_225_478_828_393_565_899_456_512", 10).unwrap(), - M12 => BigUint::from_str_radix("1_067_993_517_960_455_041_197_510_853_084_776_057_301_352_261_178_326_384_973_520_803_911_109_862_890_320_275_011_481_043_468_288", 10).unwrap(), - } - } - F64 => match self.bound_type { - B0 => match self.model_type { - M3 => BigUint::from_str_radix("302_231_454_903_657_293_676_544", 10).unwrap(), - M6 => BigUint::from_str_radix("309_485_009_821_345_068_724_781_056", 10).unwrap(), - M9 => BigUint::from_str_radix("316_912_650_057_057_350_374_175_801_344", 10).unwrap(), - M12 => BigUint::from_str_radix("324_518_553_658_426_726_783_156_020_576_256", 10).unwrap(), - } - B2 => match self.model_type { - M3 => BigUint::from_str_radix("38_685_626_227_668_133_590_597_632", 10).unwrap(), - M6 => BigUint::from_str_radix("39_614_081_257_132_168_796_771_975_168", 10).unwrap(), - M9 => BigUint::from_str_radix("20_282_409_603_651_670_423_947_251_286_016", 10).unwrap(), - M12 => BigUint::from_str_radix("20_769_187_434_139_310_514_121_985_316_880_384", 10).unwrap(), - } - B4 => match self.model_type { - M3 => BigUint::from_str_radix("2_475_880_078_570_760_549_798_248_448", 10).unwrap(), - M6 => BigUint::from_str_radix("2_535_301_200_456_458_802_993_406_410_752", 10).unwrap(), - M9 => BigUint::from_str_radix("2_596_148_429_267_413_814_265_248_164_610_048", 10).unwrap(), - M12 => BigUint::from_str_radix("2_658_455_991_569_831_745_807_614_120_560_689_152", 10).unwrap(), - } - B6 => match self.model_type { - M3 => BigUint::from_str_radix("316_912_650_057_057_350_374_175_801_344", 10).unwrap(), - M6 => BigUint::from_str_radix("324_518_553_658_426_726_783_156_020_576_256", 10).unwrap(), - M9 => BigUint::from_str_radix("332_306_998_946_228_968_225_951_765_070_086_144", 10).unwrap(), - M12 => BigUint::from_str_radix("340_282_366_920_938_463_463_374_607_431_768_211_456", 10).unwrap(), - } - Bmax => match self.model_type { - M3 => BigUint::from_str_radix("596_143_540_225_991_923_146_302_416_688_458_341_289_203_474_674_553_062_792_993_127_033_853_365_765_018_588_197_722_567_551_977_295_508_215_323_031_793_155_057_153_946_025_631_943_349_443_566_464_703_583_960_364_782_216_884_718_655_637_955_371_883_889_285_523_680_681_542_682_622_992_485_998_454_422_254_346_205_188_269_982_058_330_848_165_814_218_528_432_304_958_458_516_472_675_321_199_923_576_436_128_746_194_040_030_388_187_813_654_706_961_312_852_788_047_760_914_640_519_973_439_182_188_222_756_017_424_664_821_230_981_616_162_111_762_973_371_192_278_908_910_941_031_147_045_555_738_506_834_254_728_517_124_812_756_790_583_181_174_762_115_337_827_697_771_072_593_076_558_961_853_936_203_969_690_859_453_400_618_497_370_766_001_868_317_217_344_149_071_638_768_630_396_860_838_478_405_181_466_899_321_747_678_290_733_613_480_879_657_473_540_096", 10).unwrap(), - M6 => BigUint::from_str_radix("610_450_985_191_415_729_301_813_674_688_981_341_480_144_358_066_742_336_300_024_962_082_665_846_543_379_034_314_467_909_173_224_750_600_412_490_784_556_190_778_525_640_730_247_109_989_830_212_059_856_469_975_413_536_990_089_951_903_373_266_300_809_102_628_376_249_017_899_707_005_944_305_662_417_328_388_450_514_112_788_461_627_730_788_521_793_759_773_114_680_277_461_520_868_019_528_908_721_742_270_595_836_102_696_991_117_504_321_182_419_928_384_361_254_960_907_176_591_892_452_801_722_560_740_102_161_842_856_776_940_525_174_950_002_445_284_732_100_893_602_724_803_615_894_574_649_076_230_998_276_842_001_535_808_262_953_557_177_522_956_406_105_935_562_517_578_335_310_396_376_938_430_672_864_963_440_080_282_233_341_307_664_385_913_156_830_560_408_649_358_099_077_526_385_498_601_886_905_822_104_905_469_622_569_711_220_204_420_769_252_905_058_304", 10).unwrap(), - M9 => BigUint::from_str_radix("625_101_808_836_009_706_805_057_202_881_516_893_675_667_822_660_344_152_371_225_561_172_649_826_860_420_131_138_015_138_993_382_144_614_822_390_563_385_539_357_210_256_107_773_040_629_586_137_149_293_025_254_823_461_877_852_110_749_054_224_692_028_521_091_457_278_994_329_299_974_086_968_998_315_344_269_773_326_451_495_384_706_796_327_446_316_810_007_669_432_604_120_597_368_851_997_602_531_064_085_090_136_169_161_718_904_324_424_890_798_006_665_585_925_079_968_948_830_097_871_668_963_902_197_864_613_727_085_339_587_097_779_148_802_503_971_565_671_315_049_190_198_902_676_044_440_654_060_542_235_486_209_572_667_661_264_442_549_783_507_359_852_478_016_018_000_215_357_845_889_984_953_009_013_722_562_642_209_006_941_499_048_331_175_072_594_493_858_456_942_693_455_387_018_750_568_332_191_561_835_423_200_893_511_384_289_489_326_867_714_974_779_703_296", 10).unwrap(), - M12 => BigUint::from_str_radix("640_104_252_248_073_939_768_378_575_750_673_299_123_883_850_404_192_412_028_134_974_640_793_422_705_070_214_285_327_502_329_223_316_085_578_127_936_906_792_301_783_302_254_359_593_604_696_204_440_876_057_860_939_224_962_920_561_407_031_526_084_637_205_597_652_253_690_193_203_173_465_056_254_274_912_532_247_886_286_331_273_939_759_439_305_028_413_447_853_498_986_619_491_705_704_445_544_991_809_623_132_299_437_221_600_158_028_211_088_177_158_825_559_987_281_888_203_602_020_220_589_019_035_850_613_364_456_535_387_737_188_125_848_373_764_066_883_247_426_610_370_763_676_340_269_507_229_757_995_249_137_878_602_411_685_134_789_170_978_311_536_488_937_488_402_432_220_526_434_191_344_591_881_230_051_904_145_622_023_108_095_025_491_123_274_336_761_711_059_909_318_098_316_307_200_581_972_164_159_319_473_357_714_955_657_512_437_070_712_540_134_174_416_175_104", 10).unwrap(), - } - } - I32 => match self.bound_type { - B0 => match self.model_type { - M3 => BigUint::from_str_radix("2_048", 10).unwrap(), - M6 => BigUint::from_str_radix("2_097_152", 10).unwrap(), - M9 => BigUint::from_str_radix("2_147_483_648", 10).unwrap(), - M12 => BigUint::from_str_radix("2_199_023_255_552", 10).unwrap(), - } - B2 => match self.model_type { - M3 => BigUint::from_str_radix("262_144", 10).unwrap(), - M6 => BigUint::from_str_radix("268_435_456", 10).unwrap(), - M9 => BigUint::from_str_radix("274_877_906_944", 10).unwrap(), - M12 => BigUint::from_str_radix("281_474_976_710_656", 10).unwrap(), - } - B4 => match self.model_type { - M3 => BigUint::from_str_radix("33_554_432", 10).unwrap(), - M6 => BigUint::from_str_radix("34_359_738_368", 10).unwrap(), - M9 => BigUint::from_str_radix("35_184_372_088_832", 10).unwrap(), - M12 => BigUint::from_str_radix("36_028_797_018_963_968", 10).unwrap(), - } - B6 => match self.model_type { - M3 => BigUint::from_str_radix("2_147_483_648", 10).unwrap(), - M6 => BigUint::from_str_radix("2_199_023_255_552", 10).unwrap(), - M9 => BigUint::from_str_radix("2_251_799_813_685_248", 10).unwrap(), - M12 => BigUint::from_str_radix("2_305_843_009_213_693_952", 10).unwrap(), - } - Bmax => match self.model_type { - M3 => BigUint::from_str_radix("4_398_046_511_104", 10).unwrap(), - M6 => BigUint::from_str_radix("4_503_599_627_370_496", 10).unwrap(), - M9 => BigUint::from_str_radix("4_611_686_018_427_387_904", 10).unwrap(), - M12 => BigUint::from_str_radix("4_722_366_482_869_645_213_696", 10).unwrap(), - } - } - I64 => match self.bound_type { - B0 => match self.model_type { - M3 => BigUint::from_str_radix("2_048", 10).unwrap(), - M6 => BigUint::from_str_radix("2_097_152", 10).unwrap(), - M9 => BigUint::from_str_radix("2_147_483_648", 10).unwrap(), - M12 => BigUint::from_str_radix("2_199_023_255_552", 10).unwrap(), - } - B2 => match self.model_type { - M3 => BigUint::from_str_radix("262_144", 10).unwrap(), - M6 => BigUint::from_str_radix("268_435_456", 10).unwrap(), - M9 => BigUint::from_str_radix("274_877_906_944", 10).unwrap(), - M12 => BigUint::from_str_radix("281_474_976_710_656", 10).unwrap(), - } - B4 => match self.model_type { - M3 => BigUint::from_str_radix("33_554_432", 10).unwrap(), - M6 => BigUint::from_str_radix("34_359_738_368", 10).unwrap(), - M9 => BigUint::from_str_radix("35_184_372_088_832", 10).unwrap(), - M12 => BigUint::from_str_radix("36_028_797_018_963_968", 10).unwrap(), - } - B6 => match self.model_type { - M3 => BigUint::from_str_radix("2_147_483_648", 10).unwrap(), - M6 => BigUint::from_str_radix("2_199_023_255_552", 10).unwrap(), - M9 => BigUint::from_str_radix("2_251_799_813_685_248", 10).unwrap(), - M12 => BigUint::from_str_radix("2_305_843_009_213_693_952", 10).unwrap(), - } - Bmax => match self.model_type { - M3 => BigUint::from_str_radix("18_889_465_931_478_580_854_784", 10).unwrap(), - M6 => BigUint::from_str_radix("19_342_813_113_834_066_795_298_816", 10).unwrap(), - M9 => BigUint::from_str_radix("19_807_040_628_566_084_398_385_987_584", 10).unwrap(), - M12 => BigUint::from_str_radix("20_282_409_603_651_670_423_947_251_286_016", 10).unwrap(), - } - } - } - }; - MaskConfig { - name, - add_shift, - exp_shift, - order, - } - } -} diff --git a/rust/src/mask/config/mod.rs b/rust/src/mask/config/mod.rs new file mode 100644 index 000000000..f643607de --- /dev/null +++ b/rust/src/mask/config/mod.rs @@ -0,0 +1,606 @@ +pub(crate) mod serialization; + +use std::convert::TryFrom; +use thiserror::Error; + +use num::{ + bigint::{BigInt, BigUint}, + rational::Ratio, + traits::{pow::Pow, Num}, +}; + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[repr(u8)] +/// The order of the finite group. +pub enum GroupType { + /// A finite group of exact integer order. + Integer = 0, + /// A finite group of prime order. + Prime = 1, + /// A finite group of power-of-two order. + Power2 = 2, +} + +impl TryFrom for GroupType { + type Error = InvalidMaskConfig; + + /// Get the group type. Fails if the encoding is unknown. + fn try_from(byte: u8) -> Result { + use GroupType::{Integer, Power2, Prime}; + + Ok(match byte { + 0 => Integer, + 1 => Prime, + 2 => Power2, + _ => return Err(InvalidMaskConfig::GroupType), + }) + } +} + +#[derive(Debug, Error)] +pub enum InvalidMaskConfig { + #[error("invalid group type")] + GroupType, + #[error("invalid data type")] + DataType, + #[error("invalid bound type")] + BoundType, + #[error("invalid model type")] + ModelType, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[repr(u8)] +/// The data type of the numbers to be masked. +pub enum DataType { + /// Numbers of type f32. + F32 = 0, + /// Numbers of type f64. + F64 = 1, + /// Numbers of type i32. + I32 = 2, + /// Numbers of type i64. + I64 = 3, +} + +impl TryFrom for DataType { + type Error = InvalidMaskConfig; + + /// Get the data type. Fails if the encoding is unknown. + fn try_from(byte: u8) -> Result { + use DataType::{F32, F64, I32, I64}; + + Ok(match byte { + 0 => F32, + 1 => F64, + 2 => I32, + 3 => I64, + _ => return Err(InvalidMaskConfig::DataType), + }) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[repr(u8)] +/// The bounds of the numbers to be masked. +pub enum BoundType { + /// Numbers absolutely bounded by 1. + B0 = 0, + /// Numbers absolutely bounded by 100. + B2 = 2, + /// Numbers absolutely bounded by 10_000. + B4 = 4, + /// Numbers absolutely bounded by 1_000_000. + B6 = 6, + /// Numbers absolutely bounded by their data types' maximum absolute value. + Bmax = 255, +} + +impl TryFrom for BoundType { + type Error = InvalidMaskConfig; + + /// Get the bound type. Fails if the encoding is unknown. + fn try_from(byte: u8) -> Result { + use BoundType::{Bmax, B0, B2, B4, B6}; + + Ok(match byte { + 0 => B0, + 2 => B2, + 4 => B4, + 6 => B6, + 255 => Bmax, + _ => return Err(InvalidMaskConfig::ModelType), + }) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[repr(u8)] +/// The number of models to be aggregated at most. +pub enum ModelType { + /// At most 1_000 models to be aggregated. + M3 = 3, + /// At most 1_000_000 models to be aggregated. + M6 = 6, + /// At most 1_000_000_000 models to be aggregated. + M9 = 9, + /// At most 1_000_000_000_000 models to be aggregated. + M12 = 12, +} + +impl ModelType { + pub fn nb_models_max(&self) -> usize { + 10usize.pow(*self as u8 as u32) + } +} + +impl TryFrom for ModelType { + type Error = InvalidMaskConfig; + + /// Get the model type. Fails if the encoding is unknown. + fn try_from(byte: u8) -> Result { + use ModelType::{M12, M3, M6, M9}; + Ok(match byte { + 3 => M3, + 6 => M6, + 9 => M9, + 12 => M12, + _ => return Err(InvalidMaskConfig::ModelType), + }) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] +/// A mask configuration name. Consists of identifiers for its parts: +/// - the order of the finite group +/// - the data type of the numbers to be masked +/// - the bounds of the numbers to be masked +/// - the number of models to be aggregated at most +pub struct MaskConfig { + pub group_type: GroupType, + pub data_type: DataType, + pub bound_type: BoundType, + pub model_type: ModelType, +} + +impl MaskConfig { + /// Return the number of bytes needed to represent a mask element + /// with this configuration + pub fn bytes_per_digit(&self) -> usize { + (self.order().bits() + 7) / 8 + } + + pub fn add_shift(&self) -> Ratio { + use BoundType::{Bmax, B0, B2, B4, B6}; + use DataType::{F32, F64, I32, I64}; + match self.bound_type { + B0 => Ratio::from_integer(BigInt::from(1)), + B2 => Ratio::from_integer(BigInt::from(100)), + B4 => Ratio::from_integer(BigInt::from(10_000)), + B6 => Ratio::from_integer(BigInt::from(1_000_000)), + Bmax => match self.data_type { + // safe unwraps: all numbers are finite + F32 => Ratio::from_float(f32::MAX).unwrap(), + F64 => Ratio::from_float(f64::MAX).unwrap(), + I32 => Ratio::from_integer(-BigInt::from(i32::MIN)), + I64 => Ratio::from_integer(-BigInt::from(i64::MIN)), + }, + } + } + pub fn exp_shift(&self) -> BigInt { + use BoundType::{Bmax, B0, B2, B4, B6}; + use DataType::{F32, F64, I32, I64}; + match self.data_type { + F32 => match self.bound_type { + B0 | B2 | B4 | B6 => BigInt::from(10).pow(10_u8), + Bmax => BigInt::from(10).pow(45_u8), + }, + F64 => match self.bound_type { + B0 | B2 | B4 | B6 => BigInt::from(10).pow(20_u8), + Bmax => BigInt::from(10).pow(324_u16), + }, + I32 | I64 => BigInt::from(1), + } + } + + pub fn order(&self) -> BigUint { + use BoundType::{Bmax, B0, B2, B4, B6}; + use DataType::{F32, F64, I32, I64}; + use GroupType::{Integer, Power2, Prime}; + use ModelType::{M12, M3, M6, M9}; + + let order_str = match self.group_type { + Integer => match self.data_type { + F32 => match self.bound_type { + B0 => match self.model_type { + M3 => "20_000_000_000_000", + M6 => "20_000_000_000_000_000", + M9 => "20_000_000_000_000_000_000", + M12 => "20_000_000_000_000_000_000_000", + } + B2 => match self.model_type { + M3 => "2_000_000_000_000_000", + M6 => "2_000_000_000_000_000_000", + M9 => "2_000_000_000_000_000_000_000", + M12 => "2_000_000_000_000_000_000_000_000", + } + B4 => match self.model_type { + M3 => "200_000_000_000_000_000", + M6 => "200_000_000_000_000_000_000", + M9 => "200_000_000_000_000_000_000_000", + M12 => "200_000_000_000_000_000_000_000_000", + } + B6 => match self.model_type { + M3 => "20_000_000_000_000_000_000", + M6 => "20_000_000_000_000_000_000_000", + M9 => "20_000_000_000_000_000_000_000_000", + M12 => "20_000_000_000_000_000_000_000_000_000", + } + Bmax => match self.model_type { + M3 => "680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000", + M6 => "680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000", + M9 => "680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000", + M12 => "680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000", + } + } + F64 => match self.bound_type { + B0 => match self.model_type { + M3 => "200_000_000_000_000_000_000_000", + M6 => "200_000_000_000_000_000_000_000_000", + M9 => "200_000_000_000_000_000_000_000_000_000", + M12 => "200_000_000_000_000_000_000_000_000_000_000", + } + B2 => match self.model_type { + M3 => "20_000_000_000_000_000_000_000_000", + M6 => "20_000_000_000_000_000_000_000_000_000", + M9 => "20_000_000_000_000_000_000_000_000_000_000", + M12 => "20_000_000_000_000_000_000_000_000_000_000_000", + } + B4 => match self.model_type { + M3 => "2_000_000_000_000_000_000_000_000_000", + M6 => "2_000_000_000_000_000_000_000_000_000_000", + M9 => "2_000_000_000_000_000_000_000_000_000_000_000", + M12 => "2_000_000_000_000_000_000_000_000_000_000_000_000", + } + B6 => match self.model_type { + M3 => "200_000_000_000_000_000_000_000_000_000", + M6 => "200_000_000_000_000_000_000_000_000_000_000", + M9 => "200_000_000_000_000_000_000_000_000_000_000_000", + M12 => "200_000_000_000_000_000_000_000_000_000_000_000_000", + } + Bmax => match self.model_type { + M3 => "359_538_626_972_463_100_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000", + M6 => "359_538_626_972_463_100_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000", + M9 => "359_538_626_972_463_100_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000", + M12 => "359_538_626_972_463_100_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000", + } + } + I32 => match self.bound_type { + B0 => match self.model_type { + M3 => "2_000", + M6 => "2_000_000", + M9 => "2_000_000_000", + M12 => "2_000_000_000_000", + } + B2 => match self.model_type { + M3 => "200_000", + M6 => "200_000_000", + M9 => "200_000_000_000", + M12 => "200_000_000_000_000", + } + B4 => match self.model_type { + M3 => "20_000_000", + M6 => "20_000_000_000", + M9 => "20_000_000_000_000", + M12 => "20_000_000_000_000_000", + } + B6 => match self.model_type { + M3 => "2_000_000_000", + M6 => "2_000_000_000_000", + M9 => "2_000_000_000_000_000", + M12 => "2_000_000_000_000_000_000", + } + Bmax => match self.model_type { + M3 => "4_294_967_295_000", + M6 => "4_294_967_295_000_000", + M9 => "4_294_967_295_000_000_000", + M12 => "4_294_967_295_000_000_000_000", + } + } + I64 => match self.bound_type { + B0 => match self.model_type { + M3 => "2_000", + M6 => "2_000_000", + M9 => "2_000_000_000", + M12 => "2_000_000_000_000", + } + B2 => match self.model_type { + M3 => "200_000", + M6 => "200_000_000", + M9 => "200_000_000_000", + M12 => "200_000_000_000_000", + } + B4 => match self.model_type { + M3 => "20_000_000", + M6 => "20_000_000_000", + M9 => "20_000_000_000_000", + M12 => "20_000_000_000_000_000", + } + B6 => match self.model_type { + M3 => "2_000_000_000", + M6 => "2_000_000_000_000", + M9 => "2_000_000_000_000_000", + M12 => "2_000_000_000_000_000_000", + } + Bmax => match self.model_type { + M3 => "18_446_744_073_709_551_615_000", + M6 => "18_446_744_073_709_551_615_000_000", + M9 => "18_446_744_073_709_551_615_000_000_000", + M12 => "18_446_744_073_709_551_615_000_000_000_000", + } + } + } + Prime => match self.data_type { + F32 => match self.bound_type { + B0 => match self.model_type { + M3 => "20_000_000_000_021", + M6 => "20_000_000_000_000_003", + M9 => "20_000_000_000_000_000_011", + M12 => "20_000_000_000_000_000_000_003", + } + B2 => match self.model_type { + M3 => "2_000_000_000_000_021", + M6 => "2_000_000_000_000_000_057", + M9 => "2_000_000_000_000_000_000_069", + M12 => "2_000_000_000_000_000_000_000_003", + } + B4 => match self.model_type { + M3 => "200_000_000_000_000_003", + M6 => "200_000_000_000_000_000_089", + M9 => "200_000_000_000_000_000_000_069", + M12 => "200_000_000_000_000_000_000_000_027", + } + B6 => match self.model_type { + M3 => "20_000_000_000_000_000_011", + M6 => "20_000_000_000_000_000_000_003", + M9 => "20_000_000_000_000_000_000_000_009", + M12 => "20_000_000_000_000_000_000_000_000_131", + } + Bmax => match self.model_type { + M3 => "680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_281", + M6 => "680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_323", + M9 => "680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_191", + M12 => "680_564_700_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_000_083", + } + } + F64 => match self.bound_type { + B0 => match self.model_type { + M3 => "200_000_000_000_000_000_000_069", + M6 => "200_000_000_000_000_000_000_000_027", + M9 => "200_000_000_000_000_000_000_000_000_017", + M12 => "200_000_000_000_000_000_000_000_000_000_159", + } + B2 => match self.model_type { + M3 => "20_000_000_000_000_000_000_000_009", + M6 => "20_000_000_000_000_000_000_000_000_131", + M9 => "20_000_000_000_000_000_000_000_000_000_047", + M12 => "20_000_000_000_000_000_000_000_000_000_000_203", + } + B4 => match self.model_type { + M3 => "2_000_000_000_000_000_000_000_000_039", + M6 => "2_000_000_000_000_000_000_000_000_000_071", + M9 => "2_000_000_000_000_000_000_000_000_000_000_017", + M12 => "2_000_000_000_000_000_000_000_000_000_000_000_041", + } + B6 => match self.model_type { + M3 => "200_000_000_000_000_000_000_000_000_017", + M6 => "200_000_000_000_000_000_000_000_000_000_159", + M9 => "200_000_000_000_000_000_000_000_000_000_000_003", + M12 => "200_000_000_000_000_000_000_000_000_000_000_000_023", + } + Bmax => match self.model_type { + M3 => "359_538_626_972_463_140_000_000_000_000_000_000_000_593_874_019_667_231_666_067_439_096_529_924_969_333_439_983_391_110_599_943_465_644_007_133_099_721_551_828_263_813_044_710_323_667_390_405_279_670_626_898_022_875_314_671_948_577_301_533_414_396_469_719_048_504_306_012_596_386_638_859_340_084_030_210_314_832_025_518_258_115_226_051_894_034_477_843_584_650_149_420_090_374_373_134_876_775_786_923_748_346_298_936_467_612_015_276_401_624_887_654_050_299_443_392_510_555_689_981_501_608_709_494_004_423_956_258_647_440_955_320_257_123_787_935_493_476_104_132_776_728_548_437_783_283_112_428_445_450_269_488_453_346_610_914_359_272_368_862_786_051_728_965_455_746_393_095_846_720_860_347_644_662_201_994_241_194_193_316_457_656_284_847_050_135_299_403_149_697_261_199_957_835_824_000_531_233_031_619_352_921_347_101_423_914_861_961_738_035_659_301", + M6 => "359_538_626_972_463_139_999_999_999_999_999_999_999_903_622_106_309_601_840_402_558_296_261_360_055_843_460_163_714_984_640_183_652_353_129_826_112_739_444_431_322_400_938_984_152_600_575_421_591_212_739_537_896_016_542_591_595_727_264_024_538_428_559_469_178_136_611_680_881_710_150_818_089_794_351_154_869_285_409_959_876_691_068_635_451_827_253_162_844_058_791_343_487_286_852_635_234_799_336_668_682_655_217_329_655_102_622_197_942_194_212_857_658_834_043_465_713_831_143_523_811_067_060_369_640_438_677_832_007_091_511_212_788_398_470_391_285_320_720_769_417_737_628_120_102_221_909_739_846_753_580_817_462_645_602_854_496_103_866_327_474_145_187_363_329_320_852_679_912_679_009_543_036_760_757_409_720_574_191_338_832_841_104_183_169_976_025_577_743_061_881_721_861_634_977_765_641_182_996_194_573_448_626_763_720_938_201_976_656_541_039_724_303", + M9 => "359_538_626_972_463_139_999_999_999_999_999_999_999_904_930_781_891_526_077_660_862_016_966_437_766_478_934_820_885_791_914_528_679_207_262_530_042_483_798_832_910_003_057_874_958_310_694_484_517_139_841_166_977_272_287_522_418_122_134_527_125_053_808_273_636_647_181_903_383_717_418_169_782_215_585_647_900_802_728_035_567_327_931_187_710_919_458_230_957_036_511_507_150_288_137_858_111_024_099_126_399_746_768_695_036_546_643_813_753_385_062_385_762_652_380_150_346_615_796_407_577_297_605_069_883_839_431_646_689_072_072_214_687_584_099_356_273_959_025_519_093_953_786_032_481_175_596_842_406_101_871_239_892_163_505_527_137_519_569_046_747_947_203_065_300_865_116_331_411_924_515_285_552_096_042_635_874_474_960_733_445_241_451_746_509_870_642_272_026_256_695_499_704_624_475_309_137_281_644_358_183_373_160_068_523_639_023_207_643_484_888_657_559_597", + M12 => "359_538_626_972_463_139_999_999_999_999_999_999_999_904_931_540_467_867_407_238_817_633_447_114_203_759_664_620_787_471_913_925_990_313_859_370_016_783_101_785_327_523_046_787_247_090_978_931_042_236_128_228_564_142_680_745_383_377_953_776_024_143_512_065_781_667_978_525_748_300_241_659_425_164_472_387_573_470_260_831_720_974_578_793_447_369_507_661_739_490_218_806_790_001_765_109_117_055_431_552_295_585_457_639_803_896_262_637_528_011_897_242_316_426_079_400_392_728_240_523_639_775_219_294_589_603_009_325_941_759_217_573_340_626_063_716_838_671_315_192_395_974_939_441_284_468_885_927_433_422_082_497_928_190_254_190_935_717_337_452_741_850_223_510_814_859_331_413_287_559_285_438_144_477_756_395_583_878_761_313_295_130_567_342_888_620_541_025_745_968_373_350_261_259_032_809_052_052_475_301_496_416_128_372_300_050_762_773_363_722_300_553_930_211_649", + } + } + I32 => match self.bound_type { + B0 => match self.model_type { + M3 => "2_003", + M6 => "2_000_003", + M9 => "2_000_000_011", + M12 => "2_000_000_000_003", + } + B2 => match self.model_type { + M3 => "200_003", + M6 => "200_000_033", + M9 => "200_000_000_041", + M12 => "200_000_000_000_027", + } + B4 => match self.model_type { + M3 => "20_000_003", + M6 => "20_000_000_089", + M9 => "20_000_000_000_021", + M12 => "20_000_000_000_000_003", + } + B6 => match self.model_type { + M3 => "2_000_000_011", + M6 => "2_000_000_000_003", + M9 => "2_000_000_000_000_021", + M12 => "2_000_000_000_000_000_057", + } + Bmax => match self.model_type { + M3 => "4_294_967_295_061", + M6 => "4_294_967_295_000_079", + M9 => "4_294_967_295_000_000_023", + M12 => "4_294_967_295_000_000_000_001", + } + } + I64 => match self.bound_type { + B0 => match self.model_type { + M3 => "2_003", + M6 => "2_000_003", + M9 => "2_000_000_011", + M12 => "2_000_000_000_003", + } + B2 => match self.model_type { + M3 => "200_003", + M6 => "200_000_033", + M9 => "200_000_000_041", + M12 => "200_000_000_000_027", + } + B4 => match self.model_type { + M3 => "20_000_003", + M6 => "20_000_000_089", + M9 => "20_000_000_000_021", + M12 => "20_000_000_000_000_003", + } + B6 => match self.model_type { + M3 => "2_000_000_011", + M6 => "2_000_000_000_003", + M9 => "2_000_000_000_000_021", + M12 => "2_000_000_000_000_000_057", + } + Bmax => match self.model_type { + M3 => "18_446_744_073_709_551_615_139", + M6 => "18_446_744_073_709_551_615_000_053", + M9 => "18_446_744_073_709_551_615_000_000_133", + M12 => "18_446_744_073_709_551_615_000_000_000_199", + } + } + }, + Power2 => match self.data_type { + F32 => match self.bound_type { + B0 => match self.model_type { + M3 => "35_184_372_088_832", + M6 => "36_028_797_018_963_968", + M9 => "36_893_488_147_419_103_232", + M12 => "37_778_931_862_957_161_709_568", + } + B2 => match self.model_type { + M3 => "2_251_799_813_685_248", + M6 => "2_305_843_009_213_693_952", + M9 => "2_361_183_241_434_822_606_848", + M12 => "2_417_851_639_229_258_349_412_352", + } + B4 => match self.model_type { + M3 => "288_230_376_151_711_744", + M6 => "295_147_905_179_352_825_856", + M9 => "302_231_454_903_657_293_676_544", + M12 => "309_485_009_821_345_068_724_781_056", + } + B6 => match self.model_type { + M3 => "36_893_488_147_419_103_232", + M6 => "37_778_931_862_957_161_709_568", + M9 => "38_685_626_227_668_133_590_597_632", + M12 => "39_614_081_257_132_168_796_771_975_168", + } + Bmax => match self.model_type { + M3 => "994_646_472_819_573_284_310_764_496_293_641_680_200_912_301_594_695_434_880_927_953_786_318_994_025_066_751_066_112", + M6 => "1_018_517_988_167_243_043_134_222_844_204_689_080_525_734_196_832_968_125_318_070_224_677_190_649_881_668_353_091_698_688", + M9 => "1_042_962_419_883_256_876_169_444_192_465_601_618_458_351_817_556_959_360_325_703_910_069_443_225_478_828_393_565_899_456_512", + M12 => "1_067_993_517_960_455_041_197_510_853_084_776_057_301_352_261_178_326_384_973_520_803_911_109_862_890_320_275_011_481_043_468_288", + } + } + F64 => match self.bound_type { + B0 => match self.model_type { + M3 => "302_231_454_903_657_293_676_544", + M6 => "309_485_009_821_345_068_724_781_056", + M9 => "316_912_650_057_057_350_374_175_801_344", + M12 => "324_518_553_658_426_726_783_156_020_576_256", + } + B2 => match self.model_type { + M3 => "38_685_626_227_668_133_590_597_632", + M6 => "39_614_081_257_132_168_796_771_975_168", + M9 => "20_282_409_603_651_670_423_947_251_286_016", + M12 => "20_769_187_434_139_310_514_121_985_316_880_384", + } + B4 => match self.model_type { + M3 => "2_475_880_078_570_760_549_798_248_448", + M6 => "2_535_301_200_456_458_802_993_406_410_752", + M9 => "2_596_148_429_267_413_814_265_248_164_610_048", + M12 => "2_658_455_991_569_831_745_807_614_120_560_689_152", + } + B6 => match self.model_type { + M3 => "316_912_650_057_057_350_374_175_801_344", + M6 => "324_518_553_658_426_726_783_156_020_576_256", + M9 => "332_306_998_946_228_968_225_951_765_070_086_144", + M12 => "340_282_366_920_938_463_463_374_607_431_768_211_456", + } + Bmax => match self.model_type { + M3 => "596_143_540_225_991_923_146_302_416_688_458_341_289_203_474_674_553_062_792_993_127_033_853_365_765_018_588_197_722_567_551_977_295_508_215_323_031_793_155_057_153_946_025_631_943_349_443_566_464_703_583_960_364_782_216_884_718_655_637_955_371_883_889_285_523_680_681_542_682_622_992_485_998_454_422_254_346_205_188_269_982_058_330_848_165_814_218_528_432_304_958_458_516_472_675_321_199_923_576_436_128_746_194_040_030_388_187_813_654_706_961_312_852_788_047_760_914_640_519_973_439_182_188_222_756_017_424_664_821_230_981_616_162_111_762_973_371_192_278_908_910_941_031_147_045_555_738_506_834_254_728_517_124_812_756_790_583_181_174_762_115_337_827_697_771_072_593_076_558_961_853_936_203_969_690_859_453_400_618_497_370_766_001_868_317_217_344_149_071_638_768_630_396_860_838_478_405_181_466_899_321_747_678_290_733_613_480_879_657_473_540_096", + M6 => "610_450_985_191_415_729_301_813_674_688_981_341_480_144_358_066_742_336_300_024_962_082_665_846_543_379_034_314_467_909_173_224_750_600_412_490_784_556_190_778_525_640_730_247_109_989_830_212_059_856_469_975_413_536_990_089_951_903_373_266_300_809_102_628_376_249_017_899_707_005_944_305_662_417_328_388_450_514_112_788_461_627_730_788_521_793_759_773_114_680_277_461_520_868_019_528_908_721_742_270_595_836_102_696_991_117_504_321_182_419_928_384_361_254_960_907_176_591_892_452_801_722_560_740_102_161_842_856_776_940_525_174_950_002_445_284_732_100_893_602_724_803_615_894_574_649_076_230_998_276_842_001_535_808_262_953_557_177_522_956_406_105_935_562_517_578_335_310_396_376_938_430_672_864_963_440_080_282_233_341_307_664_385_913_156_830_560_408_649_358_099_077_526_385_498_601_886_905_822_104_905_469_622_569_711_220_204_420_769_252_905_058_304", + M9 => "625_101_808_836_009_706_805_057_202_881_516_893_675_667_822_660_344_152_371_225_561_172_649_826_860_420_131_138_015_138_993_382_144_614_822_390_563_385_539_357_210_256_107_773_040_629_586_137_149_293_025_254_823_461_877_852_110_749_054_224_692_028_521_091_457_278_994_329_299_974_086_968_998_315_344_269_773_326_451_495_384_706_796_327_446_316_810_007_669_432_604_120_597_368_851_997_602_531_064_085_090_136_169_161_718_904_324_424_890_798_006_665_585_925_079_968_948_830_097_871_668_963_902_197_864_613_727_085_339_587_097_779_148_802_503_971_565_671_315_049_190_198_902_676_044_440_654_060_542_235_486_209_572_667_661_264_442_549_783_507_359_852_478_016_018_000_215_357_845_889_984_953_009_013_722_562_642_209_006_941_499_048_331_175_072_594_493_858_456_942_693_455_387_018_750_568_332_191_561_835_423_200_893_511_384_289_489_326_867_714_974_779_703_296", + M12 => "640_104_252_248_073_939_768_378_575_750_673_299_123_883_850_404_192_412_028_134_974_640_793_422_705_070_214_285_327_502_329_223_316_085_578_127_936_906_792_301_783_302_254_359_593_604_696_204_440_876_057_860_939_224_962_920_561_407_031_526_084_637_205_597_652_253_690_193_203_173_465_056_254_274_912_532_247_886_286_331_273_939_759_439_305_028_413_447_853_498_986_619_491_705_704_445_544_991_809_623_132_299_437_221_600_158_028_211_088_177_158_825_559_987_281_888_203_602_020_220_589_019_035_850_613_364_456_535_387_737_188_125_848_373_764_066_883_247_426_610_370_763_676_340_269_507_229_757_995_249_137_878_602_411_685_134_789_170_978_311_536_488_937_488_402_432_220_526_434_191_344_591_881_230_051_904_145_622_023_108_095_025_491_123_274_336_761_711_059_909_318_098_316_307_200_581_972_164_159_319_473_357_714_955_657_512_437_070_712_540_134_174_416_175_104", + } + } + I32 => match self.bound_type { + B0 => match self.model_type { + M3 => "2_048", + M6 => "2_097_152", + M9 => "2_147_483_648", + M12 => "2_199_023_255_552", + } + B2 => match self.model_type { + M3 => "262_144", + M6 => "268_435_456", + M9 => "274_877_906_944", + M12 => "281_474_976_710_656", + } + B4 => match self.model_type { + M3 => "33_554_432", + M6 => "34_359_738_368", + M9 => "35_184_372_088_832", + M12 => "36_028_797_018_963_968", + } + B6 => match self.model_type { + M3 => "2_147_483_648", + M6 => "2_199_023_255_552", + M9 => "2_251_799_813_685_248", + M12 => "2_305_843_009_213_693_952", + } + Bmax => match self.model_type { + M3 => "4_398_046_511_104", + M6 => "4_503_599_627_370_496", + M9 => "4_611_686_018_427_387_904", + M12 => "4_722_366_482_869_645_213_696", + } + } + I64 => match self.bound_type { + B0 => match self.model_type { + M3 => "2_048", + M6 => "2_097_152", + M9 => "2_147_483_648", + M12 => "2_199_023_255_552", + } + B2 => match self.model_type { + M3 => "262_144", + M6 => "268_435_456", + M9 => "274_877_906_944", + M12 => "281_474_976_710_656", + } + B4 => match self.model_type { + M3 => "33_554_432", + M6 => "34_359_738_368", + M9 => "35_184_372_088_832", + M12 => "36_028_797_018_963_968", + } + B6 => match self.model_type { + M3 => "2_147_483_648", + M6 => "2_199_023_255_552", + M9 => "2_251_799_813_685_248", + M12 => "2_305_843_009_213_693_952", + } + Bmax => match self.model_type { + M3 => "18_889_465_931_478_580_854_784", + M6 => "19_342_813_113_834_066_795_298_816", + M9 => "19_807_040_628_566_084_398_385_987_584", + M12 => "20_282_409_603_651_670_423_947_251_286_016", + } + } + } + }; + BigUint::from_str_radix(order_str, 10).unwrap() + } +} diff --git a/rust/src/mask/config/serialization.rs b/rust/src/mask/config/serialization.rs new file mode 100644 index 000000000..ba1a0cf32 --- /dev/null +++ b/rust/src/mask/config/serialization.rs @@ -0,0 +1,111 @@ +use anyhow::{anyhow, Context}; +use std::convert::TryInto; + +use crate::{ + mask::MaskConfig, + message::{DecodeError, FromBytes, ToBytes}, +}; + +const GROUP_TYPE_FIELD: usize = 0; +const DATA_TYPE_FIELD: usize = 1; +const BOUND_TYPE_FIELD: usize = 2; +const MODEL_TYPE_FIELD: usize = 3; +pub(crate) const MASK_CONFIG_BUFFER_LEN: usize = MODEL_TYPE_FIELD; + +struct MaskConfigBuffer { + inner: T, +} + +impl> MaskConfigBuffer { + pub fn new(bytes: T) -> Result { + let buffer = Self { inner: bytes }; + buffer + .check_buffer_length() + .context("not a valid MaskConfigBuffer")?; + Ok(buffer) + } + + pub fn new_unchecked(bytes: T) -> Self { + Self { inner: bytes } + } + + pub fn check_buffer_length(&self) -> Result<(), DecodeError> { + let len = self.inner.as_ref().len(); + if len < MODEL_TYPE_FIELD { + return Err(anyhow!( + "invalid buffer length: {} < {}", + len, + MODEL_TYPE_FIELD + )); + } + Ok(()) + } + + pub fn group_type(&self) -> u8 { + self.inner.as_ref()[GROUP_TYPE_FIELD] + } + + pub fn data_type(&self) -> u8 { + self.inner.as_ref()[DATA_TYPE_FIELD] + } + pub fn bound_type(&self) -> u8 { + self.inner.as_ref()[BOUND_TYPE_FIELD] + } + pub fn model_type(&self) -> u8 { + self.inner.as_ref()[MODEL_TYPE_FIELD] + } +} + +impl> MaskConfigBuffer { + pub fn set_group_type(&mut self, value: u8) { + self.inner.as_mut()[GROUP_TYPE_FIELD] = value; + } + + pub fn set_data_type(&mut self, value: u8) { + self.inner.as_mut()[DATA_TYPE_FIELD] = value; + } + pub fn set_bound_type(&mut self, value: u8) { + self.inner.as_mut()[BOUND_TYPE_FIELD] = value; + } + pub fn set_model_type(&mut self, value: u8) { + self.inner.as_mut()[MODEL_TYPE_FIELD] = value; + } +} + +impl ToBytes for MaskConfig { + fn buffer_length(&self) -> usize { + MODEL_TYPE_FIELD + } + + fn to_bytes>(&self, buffer: &mut T) { + let mut writer = MaskConfigBuffer::new_unchecked(buffer.as_mut()); + writer.set_group_type(self.group_type as u8); + writer.set_data_type(self.data_type as u8); + writer.set_bound_type(self.bound_type as u8); + writer.set_model_type(self.model_type as u8); + } +} + +impl FromBytes for MaskConfig { + fn from_bytes>(buffer: &T) -> Result { + let reader = MaskConfigBuffer::new(buffer.as_ref())?; + Ok(Self { + group_type: reader + .group_type() + .try_into() + .context("invalid masking config")?, + data_type: reader + .data_type() + .try_into() + .context("invalid masking config")?, + bound_type: reader + .bound_type() + .try_into() + .context("invalid masking config")?, + model_type: reader + .model_type() + .try_into() + .context("invalid masking config")?, + }) + } +} diff --git a/rust/src/mask/mask_object/mod.rs b/rust/src/mask/mask_object/mod.rs new file mode 100644 index 000000000..45136703f --- /dev/null +++ b/rust/src/mask/mask_object/mod.rs @@ -0,0 +1,44 @@ +pub(crate) mod serialization; + +use std::iter::Iterator; + +use num::bigint::BigUint; +use thiserror::Error; + +use crate::mask::MaskConfig; + +#[derive(Error, Debug)] +#[error("the mask object is invalid: data is incompatible with the masking configuration")] +pub struct InvalidMaskObject; + +#[derive(Debug, Hash, PartialEq, Eq, Clone)] +pub struct MaskObject { + pub(crate) data: Vec, + pub(crate) config: MaskConfig, +} + +impl MaskObject { + pub fn new(config: MaskConfig, data: Vec) -> Self { + Self { data, config } + } + + pub fn new_empty(config: MaskConfig) -> Self { + Self { + data: vec![], + config, + } + } + + pub fn new_checked(config: MaskConfig, data: Vec) -> Result { + let obj = Self::new(config, data); + if obj.is_valid() { + Ok(obj) + } else { + Err(InvalidMaskObject) + } + } + + pub fn is_valid(&self) -> bool { + self.data.iter().all(|i| i < &self.config.order()) + } +} diff --git a/rust/src/mask/mask_object/serialization.rs b/rust/src/mask/mask_object/serialization.rs new file mode 100644 index 000000000..47e44357d --- /dev/null +++ b/rust/src/mask/mask_object/serialization.rs @@ -0,0 +1,130 @@ +use anyhow::{anyhow, Context}; +use std::{convert::TryInto, ops::Range}; + +use num::bigint::BigUint; + +use super::MaskObject; +use crate::{ + mask::{config::serialization::MASK_CONFIG_BUFFER_LEN, MaskConfig}, + message::{utils::range, DecodeError, FromBytes, ToBytes}, +}; + +const MASK_CONFIG_FIELD: Range = range(0, MASK_CONFIG_BUFFER_LEN); +const DIGITS_FIELD: Range = range(MASK_CONFIG_FIELD.end, 4); + +struct MaskObjectBuffer { + inner: T, +} +impl> MaskObjectBuffer { + pub fn new(bytes: T) -> Result { + let buffer = Self { inner: bytes }; + buffer + .check_buffer_length() + .context("not a valid MaskObject")?; + Ok(buffer) + } + + pub fn new_unchecked(bytes: T) -> Self { + Self { inner: bytes } + } + + pub fn check_buffer_length(&self) -> Result<(), DecodeError> { + let len = self.inner.as_ref().len(); + if len < DIGITS_FIELD.end { + return Err(anyhow!( + "invalid buffer length: {} < {}", + len, + DIGITS_FIELD.end + )); + } + + let config = MaskConfig::from_bytes(&self.config()).context("invalid MaskObject buffer")?; + let bytes_per_digit = config.bytes_per_digit(); + let (data_length, overflows) = (self.digits() as usize).overflowing_mul(bytes_per_digit); + if overflows { + return Err(anyhow!( + "invalid MaskObject buffer: invalid mask config or digits field" + )); + } + let total_expected_length = DIGITS_FIELD.end + data_length; + if len < total_expected_length { + return Err(anyhow!( + "invalid buffer length: expected {} bytes but buffer has only {} bytes", + total_expected_length, + len + )); + } + Ok(()) + } + + pub fn digits(&self) -> u32 { + // UNWRAP SAFE: the slice is exactly 4 bytes long + u32::from_be_bytes(self.inner.as_ref()[DIGITS_FIELD].try_into().unwrap()) + } + + pub fn config(&self) -> &[u8] { + &self.inner.as_ref()[MASK_CONFIG_FIELD] + } + + pub fn data(&self) -> &[u8] { + &self.inner.as_ref()[DIGITS_FIELD.end..] + } +} + +impl> MaskObjectBuffer { + pub fn set_digits(&mut self, value: u32) { + self.inner.as_mut()[DIGITS_FIELD].copy_from_slice(&value.to_be_bytes()); + } + pub fn config_mut(&mut self) -> &mut [u8] { + &mut self.inner.as_mut()[MASK_CONFIG_FIELD] + } + + pub fn data_mut(&mut self) -> &mut [u8] { + &mut self.inner.as_mut()[DIGITS_FIELD.end..] + } +} + +impl ToBytes for MaskObject { + fn buffer_length(&self) -> usize { + DIGITS_FIELD.end + self.config.bytes_per_digit() * self.data.len() + } + + fn to_bytes>(&self, buffer: &mut T) { + let mut writer = MaskObjectBuffer::new_unchecked(buffer.as_mut()); + self.config.to_bytes(&mut writer.config_mut()); + writer.set_digits(self.data.len() as u32); + + let mut data = writer.data_mut(); + let bytes_per_digit = self.config.bytes_per_digit(); + + for int in self.data.iter() { + // FIXME: this allocates a vec which is sub-optimal. See + // https://github.com/rust-num/num-bigint/issues/152 + let bytes = int.to_bytes_le(); + // This may panic if the data is invalid and contains + // integers that are bigger than what is expected by the + // configuration. + data[..bytes.len()].copy_from_slice(&bytes[..]); + // padding + for i in bytes.len()..bytes_per_digit { + data[i] = 0; + } + data = &mut data[bytes_per_digit..]; + } + } +} + +impl FromBytes for MaskObject { + fn from_bytes>(buffer: &T) -> Result { + let reader = MaskObjectBuffer::new(buffer.as_ref())?; + + let config = MaskConfig::from_bytes(&reader.config())?; + let mut data = Vec::with_capacity(reader.digits() as usize); + let bytes_per_digit = config.bytes_per_digit(); + for chunk in reader.data().chunks(bytes_per_digit) { + data.push(BigUint::from_bytes_le(chunk)); + } + + Ok(MaskObject { data, config }) + } +} diff --git a/rust/src/mask/masking.rs b/rust/src/mask/masking.rs new file mode 100644 index 000000000..cd89206a0 --- /dev/null +++ b/rust/src/mask/masking.rs @@ -0,0 +1,362 @@ +use rand::{RngCore, SeedableRng}; +use std::iter::{self, Iterator}; + +use num::{ + bigint::{BigInt, BigUint, ToBigInt}, + clamp, + rational::Ratio, + traits::Zero, +}; +use rand_chacha::ChaCha20Rng; + +use crate::mask::{MaskConfig, MaskObject, MaskSeed, Model}; + +use thiserror::Error; + +pub type MaskedModel = MaskObject; +pub type Mask = MaskObject; + +#[derive(Debug)] +pub struct AggregatedModel { + nb_models: usize, + model: MaskedModel, + config: MaskConfig, +} + +#[derive(Debug, Error)] +pub enum UnmaskingError { + #[error("there is no model to unmask")] + NoModel, + + #[error("too many models were aggregated for the current unmasking configuration")] + TooManyModels, + + #[error("the masked model is incompatible with the mask used for unmasking")] + MaskMismatch, + + #[error("the mask is invalid")] + InvalidMask, +} + +#[derive(Debug, Error)] +pub enum AggregationError { + #[error("the model to aggregate is invalid")] + InvalidModel, + + #[error("too many models were aggregated for the current unmasking configuration")] + TooManyModels, + + #[error("the model to aggregate is incompatible with the current aggregated model")] + ModelMismatch, +} + +impl AggregatedModel { + pub fn new(config: MaskConfig) -> Self { + Self { + nb_models: 0, + model: MaskObject { + data: vec![], + config: config.clone(), + }, + config, + } + } + + pub fn validate_unmasking(&self, mask: &Mask) -> Result<(), UnmaskingError> { + // We cannot perform unmasking without at least one real model + if self.nb_models == 0 { + return Err(UnmaskingError::NoModel); + } + + if self.nb_models > self.config.model_type.nb_models_max() { + return Err(UnmaskingError::TooManyModels); + } + + if self.config != mask.config || self.model.data.len() != mask.data.len() { + return Err(UnmaskingError::MaskMismatch); + } + + if !mask.is_valid() { + return Err(UnmaskingError::InvalidMask); + } + + Ok(()) + } + + pub fn unmask(mut self, mask: Mask) -> Model { + let scaled_add_shift = self.config.add_shift() * BigInt::from(self.nb_models); + let exp_shift = self.config.exp_shift(); + let order = self.config.order(); + self.model + .data + .drain(..) + .into_iter() + .zip(mask.data.into_iter()) + .map(|(masked_weight, mask)| { + // PANIC_SAFE: The substraction panics if it + // underflows, which can only happen if: + // + // mask > self.config.order() + // + // If the mask is valid, we are guaranteed that this + // cannot happen. Thus this method may panic only if + // given an invalid mask. + let n = (masked_weight + order.clone() - mask) % order.clone(); + + // UNWRAP_SAFE: to_bigint never fails for BigUint + let ratio = Ratio::::from(n.to_bigint().unwrap()); + + ratio / exp_shift.clone() - &scaled_add_shift + }) + .collect() + } + + pub fn validate_aggregation(&self, model: &MaskedModel) -> Result<(), AggregationError> { + if self.config != model.config || self.model.data.len() != model.data.len() { + return Err(AggregationError::ModelMismatch); + } + + if self.nb_models == self.config.model_type.nb_models_max() { + return Err(AggregationError::TooManyModels); + } + + if !model.is_valid() { + return Err(AggregationError::InvalidModel); + } + + Ok(()) + } + + pub fn aggregate(&mut self, model: MaskedModel) { + if self.nb_models == 0 { + self.model = model; + self.nb_models = 1; + return; + } + + let order = self.config.order(); + for (i, j) in self.model.data.iter_mut().zip(model.data.into_iter()) { + *i = (i.clone() + j) % order.clone() + } + self.nb_models += 1; + } +} + +pub struct Masker { + pub config: MaskConfig, + pub seed: MaskSeed, +} + +impl Masker { + pub fn new(config: MaskConfig) -> Self { + Self { + config, + seed: MaskSeed::generate(), + } + } + + pub fn with_seed(config: MaskConfig, seed: MaskSeed) -> Self { + Self { config, seed } + } +} + +impl Masker { + /// Mask the model wrt the mask configuration. Enforces bounds on the scalar and weights. + /// + /// The masking proceeds in the following steps: + /// - clamp the scalar and the weights according to the mask configuration + /// - shift the weights into the non-negative reals + /// - shift the weights into the non-negative integers + /// - shift the weights into the finite group + /// - mask the weights with random elements from the finite group + /// + /// The random elements are derived from a seeded PRNG. Unmasking proceeds in reverse order. For + /// a more detailes see [the confluence page](https://xainag.atlassian.net/wiki/spaces/FP/pages/542408769/Masking). + pub fn mask_model(self, scalar: f64, model: Model) -> (MaskSeed, MaskedModel) { + let random_ints = self.random_ints(); + let Self { seed, config } = self; + let mut masked_model = MaskedModel { + config: config.clone(), + data: vec![], + }; + masked_model.data = model + .into_iter() + .zip(random_ints) + .map(move |(weight, rand_int)| { + let scalar = Ratio::::from_float(clamp(scalar, 0_f64, 1_f64)).unwrap(); + let negative_bound = -config.add_shift(); + let positive_bound = config.add_shift(); + let scaled = scalar * clamp(weight, negative_bound.clone(), positive_bound.clone()); + // PANIC_SAFE: shifted weight is guaranteed to be non-negative + let shifted = ((scaled + config.add_shift()) * config.exp_shift()) + .to_integer() + .to_biguint() + .unwrap(); + (shifted + rand_int) % config.order() + }) + .collect(); + (seed, masked_model) + } + + fn random_ints(&self) -> impl Iterator { + let order = self.config.order(); + let mut prng = ChaCha20Rng::from_seed(self.seed.as_array()); + + iter::from_fn(move || Some(generate_integer(&mut prng, &order))) + } +} + +/// Generate a secure pseudo-random integer. Draws from a uniform distribution over the integers +/// between zero (included) and `max_int` (excluded). +pub fn generate_integer(prng: &mut ChaCha20Rng, max_int: &BigUint) -> BigUint { + if max_int.is_zero() { + return BigUint::zero(); + } + let mut bytes = max_int.to_bytes_le(); + let mut rand_int = max_int.clone(); + while rand_int >= *max_int { + prng.fill_bytes(&mut bytes); + rand_int = BigUint::from_bytes_le(&bytes); + } + rand_int +} + +#[cfg(test)] +mod tests { + use std::{convert::TryFrom, iter}; + + use rand::{ + distributions::{Distribution, Uniform}, + SeedableRng, + }; + use rand_chacha::ChaCha20Rng; + + use super::*; + use crate::{ + crypto::generate_integer, + mask::config::{BoundType, DataType, GroupType, MaskConfig, ModelType}, + model::MaskModels, + }; + + #[test] + fn test_aggregation() { + let mut prng = ChaCha20Rng::from_seed([0_u8; 32]); + let config = MaskConfig { + group_type: GroupType::Prime, + data_type: DataType::F32, + bound_type: BoundType::B0, + model_type: ModelType::M3, + }; + let integers = iter::repeat_with(|| generate_integer(&mut prng, config.order())) + .take(10) + .collect(); + let other_integers = iter::repeat_with(|| generate_integer(&mut prng, config.order())) + .take(10) + .collect(); + let masked_model = AggregatedModel::from_parts(integers, config.clone()).unwrap(); + let other_masked_model = + AggregatedModel::from_parts(other_integers, config.clone()).unwrap(); + let aggregated_masked_model = masked_model.aggregate(&other_masked_model).unwrap(); + assert_eq!( + aggregated_masked_model.integers().len(), + masked_model.integers().len(), + ); + assert_eq!( + aggregated_masked_model.integers().len(), + other_masked_model.integers().len(), + ); + assert_eq!(aggregated_masked_model.config(), &config); + assert!(aggregated_masked_model + .integers() + .iter() + .all(|integer| integer < config.order())); + } + + #[test] + fn test_masking_and_aggregation() { + let mut prng = ChaCha20Rng::from_seed([0_u8; 32]); + let uniform = Uniform::new(-1_f32, 1_f32); + let weights = iter::repeat_with(|| uniform.sample(&mut prng)) + .take(10) + .collect::>(); + let other_weights = iter::repeat_with(|| uniform.sample(&mut prng)) + .take(10) + .collect::>(); + let model = Model::try_from(weights).unwrap(); + let other_model = Model::try_from(other_weights).unwrap(); + let config = MaskConfig { + group_type: GroupType::Prime, + data_type: DataType::F32, + bound_type: BoundType::B0, + model_type: ModelType::M3, + }; + let (mask_seed, masked_model) = model.mask(0.5_f64, &config); + let (other_mask_seed, other_masked_model) = other_model.mask(0.5_f64, &config); + let aggregated_masked_model = masked_model.aggregate(&other_masked_model).unwrap(); + let aggregated_mask = mask_seed + .derive_mask(10, &config) + .aggregate(&other_mask_seed.derive_mask(10, &config)) + .unwrap(); + let aggregated_model: Model = + aggregated_masked_model.unmask(&aggregated_mask, 2).unwrap(); + let averaged_weights = model + .weights() + .iter() + .zip(other_model.weights().iter()) + .map(|(weight, other_weight)| 0.5 * weight + 0.5 * other_weight) + .collect::>(); + assert!(aggregated_model + .weights() + .iter() + .zip(averaged_weights.iter()) + .all( + |(aggregated_weight, averaged_weight)| (aggregated_weight - averaged_weight).abs() + < 1e-8_f32 + )); + } + + #[test] + fn test_serialization() { + let mut prng = ChaCha20Rng::from_seed([0_u8; 32]); + let config = MaskConfig { + group_type: GroupType::Prime, + data_type: DataType::F32, + bound_type: BoundType::B0, + model_type: ModelType::M3, + }; + let integers = iter::repeat_with(|| generate_integer(&mut prng, config.order())) + .take(10) + .collect(); + let masked_model = AggregatedModel::from_parts(integers, config).unwrap(); + assert_eq!(masked_model.len(), 64); + let serialized = masked_model.serialize(); + assert_eq!(serialized.len(), 64); + let deserialized = AggregatedModel::deserialize(serialized.as_slice()).unwrap(); + assert_eq!(masked_model, deserialized); + } + + #[test] + fn test_floats_from() { + // f32 + let ratio = vec![Ratio::from_float(0_f32).unwrap()]; + assert_eq!(floats_from::(ratio), vec![0_f32]); + let ratio = vec![Ratio::from_float(0.1_f32).unwrap()]; + assert_eq!(floats_from::(ratio), vec![0.1_f32]); + let ratio = vec![ + (Ratio::from_float(f32::max_value()).unwrap() * BigInt::from(10_usize)) + / (Ratio::from_float(f32::max_value()).unwrap() * BigInt::from(100_usize)), + ]; + assert_eq!(floats_from::(ratio), vec![0.1_f32]); + + // f64 + let ratio = vec![Ratio::from_float(0_f64).unwrap()]; + assert_eq!(floats_from::(ratio), vec![0_f64]); + let ratio = vec![Ratio::from_float(0.1_f64).unwrap()]; + assert_eq!(floats_from::(ratio), vec![0.1_f64]); + let ratio = vec![ + (Ratio::from_float(f64::max_value()).unwrap() * BigInt::from(10_usize)) + / (Ratio::from_float(f64::max_value()).unwrap() * BigInt::from(100_usize)), + ]; + assert_eq!(floats_from::(ratio), vec![0.1_f64]); + } +} diff --git a/rust/src/mask/mod.rs b/rust/src/mask/mod.rs index 0f43622e1..e1cc8335e 100644 --- a/rust/src/mask/mod.rs +++ b/rust/src/mask/mod.rs @@ -1,430 +1,13 @@ -pub mod config; -pub mod seed; - -use std::convert::TryInto; - -use num::{ - bigint::{BigInt, BigUint, ToBigInt}, - rational::Ratio, - traits::{cast::ToPrimitive, float::FloatCore}, +mod config; +mod mask_object; +mod masking; +mod model; +mod seed; + +pub use self::{ + config::{BoundType, DataType, GroupType, MaskConfig, ModelType}, + mask_object::MaskObject, + masking::{AggregatedModel, AggregationError, Mask, MaskedModel, Masker, UnmaskingError}, + model::{FromPrimitives, IntoPrimitives, Model}, + seed::{EncryptedMaskSeed, MaskSeed}, }; - -use self::config::{MaskConfig, ModelType}; -use crate::{model::Model, PetError}; - -#[allow(clippy::len_without_is_empty)] -/// Aggregation and serialization of vectors of arbitrarily large integers. -pub trait Integers: Sized { - type Error; - - define_trait_fields!( - integers, Vec; - config, MaskConfig; - ); - - /// Get an error value of the error type to be used in the default implementations. - fn error_value() -> Self::Error; - - /// Create the object from its parts. Fails if the integers don't conform to the mask - /// configuration. - fn from_parts(integers: Vec, config: MaskConfig) -> Result; - - /// Get the length of the serialized object. - fn len(&self) -> usize { - 4 + self.integers().len() * self.config().element_len() - } - - /// Serialize the object into bytes. - fn serialize(&self) -> Vec { - let element_len = self.config().element_len(); - let bytes = self - .integers() - .iter() - .flat_map(|integer| { - let mut bytes = integer.to_bytes_le(); - bytes.resize(element_len, 0_u8); - bytes - }) - .collect(); - [self.config().serialize(), bytes].concat() - } - - /// Deserialize the object from bytes. Fails if the bytes don't conform to the mask - /// configuration. - fn deserialize(bytes: &[u8]) -> Result { - if bytes.len() < 4 { - return Err(Self::error_value()); - } - let config = MaskConfig::deserialize(&bytes[..4]).or_else(|_| Err(Self::error_value()))?; - let element_len = config.element_len(); - if bytes[4..].len() % element_len != 0 { - return Err(Self::error_value()); - } - let integers = bytes[4..] - .chunks_exact(element_len) - .map(|chunk| BigUint::from_bytes_le(chunk)) - .collect::>(); - Self::from_parts(integers, config) - } - - /// Aggregate the object with another one. Fails if the mask configurations or the integer - /// lengths don't conform. - fn aggregate(&self, other: &Self) -> Result { - if self.integers().len() == other.integers().len() && self.config() == other.config() { - let aggregated_integers = self - .integers() - .iter() - .zip(other.integers().iter()) - .map(|(integer, other_integer)| (integer + other_integer) % self.config().order()) - .collect(); - Self::from_parts(aggregated_integers, self.config().clone()) - } else { - Err(Self::error_value()) - } - } -} - -/// Unmasking of vectors of arbitrarily large integers. -pub trait MaskIntegers: Integers { - /// Unmask the masked model with a mask. Fails if the mask configuration is violated. - fn unmask(&self, mask: &Mask, no_models: usize) -> Result, Self::Error>; - - /// Cast the ratios as numbers. Fails if a ratio is not representable as `N`. - fn numbers_from(ratios: Vec>) -> Option>; - - /// Unmask the masked numbers with a mask. Fails if the mask configuration is violated. - fn unmask_numbers(&self, mask: &Mask, no_models: usize) -> Result, Self::Error> { - let max_models = match self.config().name().model_type() { - ModelType::M3 => 1_000, - ModelType::M6 => 1_000_000, - ModelType::M9 => 1_000_000_000, - ModelType::M12 => 1_000_000_000_000, - }; - if no_models > 0 - && no_models <= max_models - && self.integers().len() == mask.integers().len() - && self.config() == mask.config() - { - let scaled_add_shift = self.config().add_shift() * BigInt::from(no_models); - let ratios = self - .integers() - .iter() - .zip(mask.integers().iter()) - .map(|(masked_weight, mask)| { - let unmasked = Ratio::::from( - ((masked_weight + self.config().order() - mask) % self.config().order()) - .to_bigint() - // safe unwrap: `to_bigint` never fails for `BigUint`s - .unwrap(), - ); - unmasked / self.config().exp_shift() - &scaled_add_shift - }) - .collect::>>(); - Self::numbers_from(ratios).ok_or_else(Self::error_value) - } else { - Err(Self::error_value()) - } - } -} - -#[derive(Clone, Debug, PartialEq, Eq)] -/// A masked model. Its parameters are represented as a vector of integers from a finite group wrt -/// a mask configuration. -pub struct MaskedModel { - integers: Vec, - config: MaskConfig, -} - -impl Integers for MaskedModel { - type Error = PetError; - - derive_trait_fields!( - integers, Vec; - config, MaskConfig; - ); - - /// Get an error value of the error type to be used in the default implementations. - fn error_value() -> Self::Error { - Self::Error::InvalidModel - } - - /// Create a masked model from its parts. Fails if the integers don't conform to the mask - /// configuration. - fn from_parts(integers: Vec, config: MaskConfig) -> Result { - if integers.iter().all(|integer| integer < config.order()) { - Ok(Self { integers, config }) - } else { - Err(Self::Error::InvalidModel) - } - } -} - -impl MaskIntegers for MaskedModel { - /// Unmask the masked model with a mask. Fails if the mask configuration is violated. - fn unmask(&self, mask: &Mask, no_models: usize) -> Result, PetError> { - >::unmask_numbers(&self, mask, no_models)?.try_into() - } - - /// Cast the ratios as numbers. Fails if a ratio is not representable as `f32`. - fn numbers_from(ratios: Vec>) -> Option> { - floats_from(ratios) - } -} - -impl MaskIntegers for MaskedModel { - /// Unmask the masked model with a mask. Fails if the mask configuration is violated. - fn unmask(&self, mask: &Mask, no_models: usize) -> Result, PetError> { - >::unmask_numbers(&self, mask, no_models)?.try_into() - } - - /// Cast the ratios as numbers. Fails if a ratio is not representable as `f64`. - fn numbers_from(ratios: Vec>) -> Option> { - floats_from(ratios) - } -} - -impl MaskIntegers for MaskedModel { - /// Unmask the masked model with a mask. Fails if the mask configuration is violated. - fn unmask(&self, mask: &Mask, no_models: usize) -> Result, PetError> { - Ok(>::unmask_numbers(&self, mask, no_models)?.into()) - } - - /// Cast the ratios as numbers. Fails if a ratio is not representable as `i32`. - fn numbers_from(ratios: Vec>) -> Option> { - ratios - .iter() - .map(|ratio| (ratio.to_integer().to_i32())) - .collect() - } -} - -impl MaskIntegers for MaskedModel { - /// Unmask the masked model with a mask. Fails if the mask configuration is violated. - fn unmask(&self, mask: &Mask, no_models: usize) -> Result, PetError> { - Ok(>::unmask_numbers(&self, mask, no_models)?.into()) - } - - /// Cast the ratios as numbers. Fails if a ratio is not representable as `i64`. - fn numbers_from(ratios: Vec>) -> Option> { - ratios - .iter() - .map(|ratio| (ratio.to_integer().to_i64())) - .collect() - } -} - -/// Cast the ratios as floats. Fails if a ratio is not representable as `F`. -fn floats_from(ratios: Vec>) -> Option> { - // safe unwraps: values are finite - let min_value = &Ratio::from_float(F::min_value()).unwrap(); - let max_value = &Ratio::from_float(F::max_value()).unwrap(); - ratios - .iter() - .map(|ratio| { - if min_value <= ratio && ratio <= max_value { - let mut numer = ratio.numer().clone(); - let mut denom = ratio.denom().clone(); - // safe loop: terminates after at most bit-length of ratio iterations - loop { - if let (Some(n), Some(d)) = (F::from(numer.clone()), F::from(denom.clone())) { - if n == F::zero() || d == F::zero() { - break Some(F::zero()); - } else { - let float = n / d; - if float.is_finite() { - break Some(float); - } - } - } else { - numer >>= 1_usize; - denom >>= 1_usize; - } - } - } else { - None - } - }) - .collect() -} - -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -/// A mask. Its parameters are represented as a vector of integers from a finite group wrt a mask -/// configuration. -pub struct Mask { - integers: Vec, - config: MaskConfig, -} - -impl Integers for Mask { - type Error = PetError; - - derive_trait_fields!( - integers, Vec; - config, MaskConfig; - ); - - /// Get an error value of the error type to be used in the default implementations. - fn error_value() -> Self::Error { - Self::Error::InvalidMask - } - - /// Create a mask from its parts. Fails if the integers don't conform to the mask configuration. - fn from_parts(integers: Vec, config: MaskConfig) -> Result { - if integers.iter().all(|integer| integer < config.order()) { - Ok(Self { integers, config }) - } else { - Err(Self::Error::InvalidMask) - } - } -} - -#[cfg(test)] -mod tests { - use std::{convert::TryFrom, iter}; - - use rand::{ - distributions::{Distribution, Uniform}, - SeedableRng, - }; - use rand_chacha::ChaCha20Rng; - - use super::*; - use crate::{ - crypto::generate_integer, - mask::config::{BoundType, DataType, GroupType, MaskConfigs, ModelType}, - model::MaskModels, - }; - - #[test] - fn test_aggregation() { - let mut prng = ChaCha20Rng::from_seed([0_u8; 32]); - let config = MaskConfigs::from_parts( - GroupType::Prime, - DataType::F32, - BoundType::B0, - ModelType::M3, - ) - .config(); - let integers = iter::repeat_with(|| generate_integer(&mut prng, config.order())) - .take(10) - .collect(); - let other_integers = iter::repeat_with(|| generate_integer(&mut prng, config.order())) - .take(10) - .collect(); - let masked_model = MaskedModel::from_parts(integers, config.clone()).unwrap(); - let other_masked_model = MaskedModel::from_parts(other_integers, config.clone()).unwrap(); - let aggregated_masked_model = masked_model.aggregate(&other_masked_model).unwrap(); - assert_eq!( - aggregated_masked_model.integers().len(), - masked_model.integers().len(), - ); - assert_eq!( - aggregated_masked_model.integers().len(), - other_masked_model.integers().len(), - ); - assert_eq!(aggregated_masked_model.config(), &config); - assert!(aggregated_masked_model - .integers() - .iter() - .all(|integer| integer < config.order())); - } - - #[test] - fn test_masking_and_aggregation() { - let mut prng = ChaCha20Rng::from_seed([0_u8; 32]); - let uniform = Uniform::new(-1_f32, 1_f32); - let weights = iter::repeat_with(|| uniform.sample(&mut prng)) - .take(10) - .collect::>(); - let other_weights = iter::repeat_with(|| uniform.sample(&mut prng)) - .take(10) - .collect::>(); - let model = Model::try_from(weights).unwrap(); - let other_model = Model::try_from(other_weights).unwrap(); - let config = MaskConfigs::from_parts( - GroupType::Prime, - DataType::F32, - BoundType::B0, - ModelType::M3, - ) - .config(); - let (mask_seed, masked_model) = model.mask(0.5_f64, &config); - let (other_mask_seed, other_masked_model) = other_model.mask(0.5_f64, &config); - let aggregated_masked_model = masked_model.aggregate(&other_masked_model).unwrap(); - let aggregated_mask = mask_seed - .derive_mask(10, &config) - .aggregate(&other_mask_seed.derive_mask(10, &config)) - .unwrap(); - let aggregated_model: Model = - aggregated_masked_model.unmask(&aggregated_mask, 2).unwrap(); - let averaged_weights = model - .weights() - .iter() - .zip(other_model.weights().iter()) - .map(|(weight, other_weight)| 0.5 * weight + 0.5 * other_weight) - .collect::>(); - assert!(aggregated_model - .weights() - .iter() - .zip(averaged_weights.iter()) - .all( - |(aggregated_weight, averaged_weight)| (aggregated_weight - averaged_weight).abs() - < 1e-8_f32 - )); - } - - #[test] - fn test_serialization() { - let mut prng = ChaCha20Rng::from_seed([0_u8; 32]); - let config = MaskConfigs::from_parts( - GroupType::Prime, - DataType::F32, - BoundType::B0, - ModelType::M3, - ) - .config(); - let integers = iter::repeat_with(|| generate_integer(&mut prng, config.order())) - .take(10) - .collect(); - let masked_model = MaskedModel::from_parts(integers, config).unwrap(); - assert_eq!(masked_model.len(), 64); - let serialized = masked_model.serialize(); - assert_eq!(serialized.len(), 64); - let deserialized = MaskedModel::deserialize(serialized.as_slice()).unwrap(); - assert_eq!(masked_model, deserialized); - } - - #[test] - fn test_floats_from_f32() { - let ratio = vec![Ratio::from_float(0_f32).unwrap()]; - assert_eq!(floats_from::(ratio).unwrap(), vec![0_f32]); - let ratio = vec![Ratio::from_float(0.1_f32).unwrap()]; - assert_eq!(floats_from::(ratio).unwrap(), vec![0.1_f32]); - let ratio = vec![ - (Ratio::from_float(f32::MAX).unwrap() * BigInt::from(10)) - / (Ratio::from_float(f32::MAX).unwrap() * BigInt::from(100)), - ]; - assert_eq!(floats_from::(ratio).unwrap(), vec![0.1_f32]); - let ratio = vec![Ratio::from_float(f32::MAX).unwrap() + BigInt::from(1)]; - floats_from::(ratio).unwrap_none(); - let ratio = vec![Ratio::from_float(f32::MIN).unwrap() - BigInt::from(1)]; - floats_from::(ratio).unwrap_none(); - } - - #[test] - fn test_floats_from_f64() { - let ratio = vec![Ratio::from_float(0_f64).unwrap()]; - assert_eq!(floats_from::(ratio).unwrap(), vec![0_f64]); - let ratio = vec![Ratio::from_float(0.1_f64).unwrap()]; - assert_eq!(floats_from::(ratio).unwrap(), vec![0.1_f64]); - let ratio = vec![ - (Ratio::from_float(f64::MAX).unwrap() * BigInt::from(10)) - / (Ratio::from_float(f64::MAX).unwrap() * BigInt::from(100)), - ]; - assert_eq!(floats_from::(ratio).unwrap(), vec![0.1_f64]); - let ratio = vec![Ratio::from_float(f64::MAX).unwrap() + BigInt::from(1)]; - floats_from::(ratio).unwrap_none(); - let ratio = vec![Ratio::from_float(f64::MIN).unwrap() - BigInt::from(1)]; - floats_from::(ratio).unwrap_none(); - } -} diff --git a/rust/src/mask/model.rs b/rust/src/mask/model.rs new file mode 100644 index 000000000..d7ab9c18a --- /dev/null +++ b/rust/src/mask/model.rs @@ -0,0 +1,531 @@ +use std::{ + fmt::Debug, + iter::{FromIterator, IntoIterator}, +}; + +use derive_more::{AsMut, AsRef, Display, From, Index, IndexMut, Into}; +use num::{ + bigint::BigInt, + clamp, + rational::Ratio, + traits::{float::FloatCore, identities::Zero, ToPrimitive}, +}; +use thiserror::Error; + +/// Represent a model. +#[derive(Debug, Clone, AsRef, AsMut, PartialEq, Hash, From, Index, IndexMut, Into)] +pub struct Model(Vec>); + +impl FromIterator> for Model { + fn from_iter>>(iter: I) -> Self { + let data: Vec> = iter.into_iter().collect(); + Model(data) + } +} + +impl IntoIterator for Model { + type Item = Ratio; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +#[derive(Debug, Display)] +enum PrimitiveType { + F32, + F64, + I32, + I64, +} + +#[derive(Error, Debug)] +#[error("Could not convert weight {weight} to primitive type {target}")] +pub struct ModelCastError { + weight: Ratio, + target: PrimitiveType, +} + +#[derive(Error, Debug)] +#[error("Could not convert primitive type {0:?} to model weight")] +pub struct PrimitiveCastError(P); + +/// Convert this type into a an iterator of type `P`. This trait is +/// used to convert a [`Model`], which has its own internal +/// representation o the weights into primitive types (`f64`, `f32`, +/// `i32` `i64`). +pub trait IntoPrimitives: Sized { + /// Consume this model and into an iterator that yields `Ok(P)` + /// for each model weight that can be converted to `P`, and + /// `Err(ModelCastError)` for each weight that cannot be converted + /// to `P`. + fn into_primitives(self) -> Box>>; + + /// Consume this model and into an iterator that yields `P` values. + /// + /// # Panics + /// + /// This method panics if a model weight cannot be converted into + /// `P`. + fn into_primitives_unchecked(self) -> Box> { + Box::new( + self.into_primitives() + .map(|res| res.expect("conversion to primitive type failed")), + ) + } +} + +pub trait FromPrimitives: Sized { + /// Consume an iterator that yields `P`, into a model. If a `P` + /// cannot be converted to a model weight, this method fails. + fn from_primitives>(iter: I) -> Result>; + + /// Consume an iterator that yields `P` values into a model. If a + /// `P` cannot be directly converted into a model weight because it is not finite, it is clamped. + /// clamped. + fn from_primitives_bounded>(iter: I) -> Self; +} + +impl IntoPrimitives for Model { + fn into_primitives(self) -> Box>> { + Box::new(self.0.into_iter().map(|i| { + i.to_integer().to_i32().ok_or_else(|| ModelCastError { + weight: i, + target: PrimitiveType::I32, + }) + })) + } +} + +impl FromPrimitives for Model { + fn from_primitives>(iter: I) -> Result> { + Ok(iter.map(|p| Ratio::from_integer(BigInt::from(p))).collect()) + } + + fn from_primitives_bounded>(iter: I) -> Self { + Self::from_primitives(iter).unwrap() + } +} + +impl IntoPrimitives for Model { + fn into_primitives(self) -> Box>> { + Box::new(self.0.into_iter().map(|i| { + i.to_integer().to_i64().ok_or_else(|| ModelCastError { + weight: i, + target: PrimitiveType::I64, + }) + })) + } +} + +impl FromPrimitives for Model { + fn from_primitives>(iter: I) -> Result> { + Ok(iter.map(|p| Ratio::from_integer(BigInt::from(p))).collect()) + } + + fn from_primitives_bounded>(iter: I) -> Self { + Self::from_primitives(iter).unwrap() + } +} + +impl IntoPrimitives for Model { + fn into_primitives(self) -> Box>> { + let iter = self.0.into_iter().map(|r| { + ratio_to_float::(&r).ok_or_else(|| ModelCastError { + weight: r, + target: PrimitiveType::F32, + }) + }); + Box::new(iter) + } +} + +impl FromPrimitives for Model { + fn from_primitives>(iter: I) -> Result> { + iter.map(|f| float_to_ratio::(f).ok_or_else(|| PrimitiveCastError(f))) + .collect() + } + + fn from_primitives_bounded>(iter: I) -> Self { + iter.map(float_to_ratio_bounded::).collect() + } +} + +impl IntoPrimitives for Model { + fn into_primitives(self) -> Box>> { + let iter = self.0.into_iter().map(|r| { + ratio_to_float::(&r).ok_or_else(|| ModelCastError { + weight: r, + target: PrimitiveType::F64, + }) + }); + Box::new(iter) + } +} + +impl FromPrimitives for Model { + fn from_primitives>(iter: I) -> Result> { + iter.map(|f| float_to_ratio::(f).ok_or_else(|| PrimitiveCastError(f))) + .collect() + } + + fn from_primitives_bounded>(iter: I) -> Self { + iter.map(float_to_ratio_bounded::).collect() + } +} + +fn ratio_to_float(ratio: &Ratio) -> Option { + let min_value = Ratio::from_float(F::min_value()).unwrap(); + let max_value = Ratio::from_float(F::max_value()).unwrap(); + + if ratio < &min_value || ratio > &max_value { + return None; + } + + let mut numer = ratio.numer().clone(); + let mut denom = ratio.denom().clone(); + // safe loop: terminates after at most bit-length of ratio iterations + let f = loop { + if let (Some(n), Some(d)) = (F::from(numer.clone()), F::from(denom.clone())) { + if n == F::zero() || d == F::zero() { + break F::zero(); + } else { + let f = n / d; + if f.is_finite() { + break f; + } + } + } else { + numer >>= 1_usize; + denom >>= 1_usize; + } + }; + Some(f) +} + +/// Cast the given float to a ratio. Positive/negative infinity is +/// mapped to max/min and NaN to zero. +fn float_to_ratio_bounded(f: F) -> Ratio { + if f.is_nan() { + Ratio::::zero() + } else { + let finite_f = clamp(f.clone(), F::min_value(), F::max_value()); + // safe unwrap: clamped weight is guaranteed to be finite + Ratio::::from_float(finite_f).unwrap() + } +} + +/// Cast the given float to a ratio. If `f` cannot be converted to a +/// ratio, `None` is returned. +fn float_to_ratio(f: F) -> Option> { + if f.is_nan() || f.is_finite() { + None + } else { + Some(float_to_ratio_bounded(f)) + } +} + +#[cfg(test)] +mod tests { + use std::iter; + + use rand::distributions::{Distribution, Uniform}; + + use super::*; + use crate::mask::{ + config::{ + BoundType::{Bmax, B0, B2, B4, B6}, + DataType::{F32, F64, I32, I64}, + GroupType::{Integer, Power2, Prime}, + MaskConfigs, + ModelType::{M12, M3, M6, M9}, + }, + seed::MaskSeed, + MaskIntegers, + }; + + type R = Ratio; + + #[test] + fn test_model_f32() { + let expected_primitives = vec![-1_f32, 0_f32, 1_f32]; + let expected_model = vec![ + R::from_float(-1_f32).unwrap(), + R::zero(), + R::from_float(1_f32).unwrap(), + ]; + + let actual_model = Model::from_primitives(expected_primitives.iter().cloned()); + assert_eq!(actual_model, expected_model); + + let actual_model = Model::from_primitives_bounded(expected_primitives.iter().cloned()); + assert_eq!(actual_model, expected_model); + + let actual_primitives = expected_model.into_primitives_unchecked().collect(); + assert_eq!(actual_primitives, expected_primitives); + } + + #[test] + fn test_model_f64() { + let expected_primitives = vec![-1_f64, 0_f64, 1_f64]; + let expected_model = vec![ + R::from_float(-1_f64).unwrap(), + R::zero(), + R::from_float(1_f64).unwrap(), + ]; + + let actual_model = Model::from_primitives(expected_primitives.iter().cloned()); + assert_eq!(actual_model, expected_model); + + let actual_model = Model::from_primitives_bounded(expected_primitives.iter().cloned()); + assert_eq!(actual_model, expected_model); + + let actual_primitives = expected_model.into_primitives_unchecked().collect(); + assert_eq!(actual_primitives, expected_primitives); + } + + #[test] + fn test_model_f32_from_weird_primitives() { + // +infinity + assert!(Model::from_primitives(iter::once(f32::INFINITY).is_err())); + assert!( + Model::from_primitives_bounded(iter::once(f32::INFINITY)), + vec![R::from_float(f32::MAX).unwrap()].into() + ); + + // -infinity + assert!( + Model::from_primitives_bounded(iter::once(f32::NEG_INFINITY)), + vec![R::from_float(f32::MIN).unwrap()].into() + ); + assert!(Model::from_primitives( + iter::once(f32::NEG_INFINITY).is_err() + )); + + // NaN + assert!(Model::from_primitives(iter::once(f32::NAN).is_err())); + assert!( + Model::from_primitives_bounded(iter::once(f32::NAN)), + vec![R::from_float(0).unwrap()].into() + ); + } + + #[test] + fn test_model_f64_from_weird_primitives() { + // +infinity + assert!(Model::from_primitives(iter::once(f64::INFINITY).is_err())); + assert!( + Model::from_primitives_bounded(iter::once(f64::INFINITY)), + vec![R::from_float(f64::MAX).unwrap()].into() + ); + + // -infinity + assert!( + Model::from_primitives_bounded(iter::once(f64::NEG_INFINITY)), + vec![R::from_float(f64::MIN).unwrap()].into() + ); + assert!(Model::from_primitives( + iter::once(f64::NEG_INFINITY).is_err() + )); + + // NaN + assert!(Model::from_primitives(iter::once(f64::NAN).is_err())); + assert!( + Model::from_primitives_bounded(iter::once(f64::NAN)), + vec![R::from_float(0).unwrap()].into() + ); + } + + #[test] + fn test_model_i32() { + let expected_primitives = vec![-1_i32, 0_i32, 1_i32]; + let expected_model = vec![ + R::from_integer(BigInt::from(-1_i32)).unwrap(), + R::zero(), + R::from_float(BigInt::from(1_i32)).unwrap(), + ]; + + let actual_model = Model::from_primitives(expected_primitives.iter().cloned()); + assert_eq!(actual_model, expected_model); + + let actual_model = Model::from_primitives_bounded(expected_primitives.iter().cloned()); + assert_eq!(actual_model, expected_model); + + let actual_primitives = expected_model.into_primitives_unchecked().collect(); + assert_eq!(actual_primitives, expected_primitives); + } + + #[test] + fn test_model_i64() { + let expected_primitives = vec![-1_i64, 0_i64, 1_i64]; + let expected_model = vec![ + R::from_integer(BigInt::from(-1_i64)).unwrap(), + R::zero(), + R::from_float(BigInt::from(1_i64)).unwrap(), + ]; + + let actual_model = Model::from_primitives(expected_primitives.iter().cloned()); + assert_eq!(actual_model, expected_model); + + let actual_model = Model::from_primitives_bounded(expected_primitives.iter().cloned()); + assert_eq!(actual_model, expected_model); + + let actual_primitives = expected_model.into_primitives_unchecked().collect(); + assert_eq!(actual_primitives, expected_primitives); + } + + // /// Generate tests for masking and unmasking. The tests proceed in the following steps: + // /// - generate random weights from a uniform distribution with a seeded PRNG + // /// - create a model from the weights and mask it + // /// - check that all masked weights belong to the chosen finite group + // /// - unmask the masked model + // /// - check that all unmasked weights are equal to the original weights (up to a tolerance) + // /// + // /// The arguments to the macro are: + // /// - a suffix for the test name + // /// - the group type of the model (variants of `GroupType`) + // /// - the data type of the model (either primitives or variants of `DataType`) + // /// - an absolute bound for the weights (optional, choices: 1, 100, 10_000, 1_000_000) + // /// - the number of weights + // /// - a tolerance for the equality check + // /// + // /// For float data types the error depends on the order of magnitude of the weights, therefore + // /// it may be necessary to raise the tolerance or bound the random weights if this test fails. + // macro_rules! test_masking { + // ($suffix:ident, $group:ty, $data:ty, $bound:expr, $len:expr, $tol:expr $(,)?) => { + // paste::item! { + // #[test] + // fn []() { + // paste::expr! { + // let mut prng = ChaCha20Rng::from_seed([0_u8; 32]); + // let uniform = if $bound == 0 { + // Uniform::new([<$data:lower>]::MIN, [<$data:lower>]::MAX) + // } else { + // Uniform::new(-$bound as [<$data:lower>], $bound as [<$data:lower>]) + // }; + // let weights = iter::repeat_with(|| uniform.sample(&mut prng)) + // .take($len as usize) + // .collect::>(); + // let model = Model::try_from(weights).unwrap(); + // let bound_type = match $bound { + // 1 => B0, + // 100 => B2, + // 10_000 => B4, + // 1_000_000 => B6, + // 0 => Bmax, + // _ => panic!("Unknown bound!") + // }; + // let config = MaskConfigs::from_parts( + // $group, + // [<$data:upper>], + // bound_type, + // M3 + // ).config(); + // let (mask_seed, masked_model) = model.mask(1_f64, &config); + // assert_eq!( + // masked_model.integers().len(), + // $len as usize + // ); + // assert!( + // masked_model + // .integers() + // .iter() + // .all(|integer| integer < config.order()) + // ); + // let mask = mask_seed.derive_mask($len as usize, &config); + // let unmasked_model: Model<[<$data:lower>]> = masked_model + // .unmask(&mask, 1_usize) + // .unwrap(); + // assert!( + // model + // .weights() + // .iter() + // .zip(unmasked_model.weights().iter()) + // .all( + // |(weight, unmasked_weight)| + // (weight - unmasked_weight).abs() + // <= $tol as [<$data:lower>] + // ) + // ); + // } + // } + // } + // }; + // ($suffix:ident, $group:ty, $data:ty, $len:expr, $tol:expr $(,)?) => { + // test_masking!($suffix, $group, $data, 0, $len, $tol); + // }; + // } + + // test_masking!(int_f32_b0, Integer, f32, 1, 10, 1e-3); + // test_masking!(int_f32_b2, Integer, f32, 100, 10, 1e-3); + // test_masking!(int_f32_b4, Integer, f32, 10_000, 10, 1e-3); + // test_masking!(int_f32_b6, Integer, f32, 1_000_000, 10, 1e-3); + // test_masking!(int_f32_bmax, Integer, f32, 10, 1e-3); + + // test_masking!(prime_f32_b0, Prime, f32, 1, 10, 1e-3); + // test_masking!(prime_f32_b2, Prime, f32, 100, 10, 1e-3); + // test_masking!(prime_f32_b4, Prime, f32, 10_000, 10, 1e-3); + // test_masking!(prime_f32_b6, Prime, f32, 1_000_000, 10, 1e-3); + // test_masking!(prime_f32_bmax, Prime, f32, 10, 1e-3); + + // test_masking!(pow_f32_b0, Power2, f32, 1, 10, 1e-3); + // test_masking!(pow_f32_b2, Power2, f32, 100, 10, 1e-3); + // test_masking!(pow_f32_b4, Power2, f32, 10_000, 10, 1e-3); + // test_masking!(pow_f32_b6, Power2, f32, 1_000_000, 10, 1e-3); + // test_masking!(pow_f32_bmax, Power2, f32, 10, 1e-3); + + // test_masking!(int_f64_b0, Integer, f64, 1, 10, 1e-3); + // test_masking!(int_f64_b2, Integer, f64, 100, 10, 1e-3); + // test_masking!(int_f64_b4, Integer, f64, 10_000, 10, 1e-3); + // test_masking!(int_f64_b6, Integer, f64, 1_000_000, 10, 1e-3); + // test_masking!(int_f64_bmax, Integer, f64, 10, 1e-3); + + // test_masking!(prime_f64_b0, Prime, f64, 1, 10, 1e-3); + // test_masking!(prime_f64_b2, Prime, f64, 100, 10, 1e-3); + // test_masking!(prime_f64_b4, Prime, f64, 10_000, 10, 1e-3); + // test_masking!(prime_f64_b6, Prime, f64, 1_000_000, 10, 1e-3); + // test_masking!(prime_f64_bmax, Prime, f64, 10, 1e-3); + + // test_masking!(pow_f64_b0, Power2, f64, 1, 10, 1e-3); + // test_masking!(pow_f64_b2, Power2, f64, 100, 10, 1e-3); + // test_masking!(pow_f64_b4, Power2, f64, 10_000, 10, 1e-3); + // test_masking!(pow_f64_b6, Power2, f64, 1_000_000, 10, 1e-3); + // test_masking!(pow_f64_bmax, Power2, f64, 10, 1e-3); + + // test_masking!(int_i32_b0, Integer, i32, 1, 10, 1e-3); + // test_masking!(int_i32_b2, Integer, i32, 100, 10, 1e-3); + // test_masking!(int_i32_b4, Integer, i32, 10_000, 10, 1e-3); + // test_masking!(int_i32_b6, Integer, i32, 1_000_000, 10, 1e-3); + // test_masking!(int_i32_bmax, Integer, i32, 10, 1e-3); + + // test_masking!(prime_i32_b0, Prime, i32, 1, 10, 1e-3); + // test_masking!(prime_i32_b2, Prime, i32, 100, 10, 1e-3); + // test_masking!(prime_i32_b4, Prime, i32, 10_000, 10, 1e-3); + // test_masking!(prime_i32_b6, Prime, i32, 1_000_000, 10, 1e-3); + // test_masking!(prime_i32_bmax, Prime, i32, 10, 1e-3); + + // test_masking!(pow_i32_b0, Power2, i32, 1, 10, 1e-3); + // test_masking!(pow_i32_b2, Power2, i32, 100, 10, 1e-3); + // test_masking!(pow_i32_b4, Power2, i32, 10_000, 10, 1e-3); + // test_masking!(pow_i32_b6, Power2, i32, 1_000_000, 10, 1e-3); + // test_masking!(pow_i32_bmax, Power2, i32, 10, 1e-3); + + // test_masking!(int_i64_b0, Integer, i64, 1, 10, 1e-3); + // test_masking!(int_i64_b2, Integer, i64, 100, 10, 1e-3); + // test_masking!(int_i64_b4, Integer, i64, 10_000, 10, 1e-3); + // test_masking!(int_i64_b6, Integer, i64, 1_000_000, 10, 1e-3); + // test_masking!(int_i64_bmax, Integer, i64, 10, 1e-3); + + // test_masking!(prime_i64_b0, Prime, i64, 1, 10, 1e-3); + // test_masking!(prime_i64_b2, Prime, i64, 100, 10, 1e-3); + // test_masking!(prime_i64_b4, Prime, i64, 10_000, 10, 1e-3); + // test_masking!(prime_i64_b6, Prime, i64, 1_000_000, 10, 1e-3); + // test_masking!(prime_i64_bmax, Prime, i64, 10, 1e-3); + + // test_masking!(pow_i64_b0, Power2, i64, 1, 10, 1e-3); + // test_masking!(pow_i64_b2, Power2, i64, 100, 10, 1e-3); + // test_masking!(pow_i64_b4, Power2, i64, 10_000, 10, 1e-3); + // test_masking!(pow_i64_b6, Power2, i64, 1_000_000, 10, 1e-3); + // test_masking!(pow_i64_bmax, Power2, i64, 10, 1e-3); +} diff --git a/rust/src/mask/seed.rs b/rust/src/mask/seed.rs index 92fcf63d4..b72c4c714 100644 --- a/rust/src/mask/seed.rs +++ b/rust/src/mask/seed.rs @@ -7,7 +7,7 @@ use sodiumoxide::{crypto::box_, randombytes::randombytes}; use crate::{ crypto::{generate_integer, ByteObject, SEALBYTES}, - mask::{config::MaskConfig, Integers, Mask}, + mask::{config::MaskConfig, Mask}, PetError, SumParticipantEphemeralPublicKey, SumParticipantEphemeralSecretKey, @@ -25,7 +25,7 @@ impl ByteObject for MaskSeed { /// Create a mask seed initialized to zero. fn zeroed() -> Self { - Self(box_::Seed([0_u8; Self::BYTES])) + Self(box_::Seed([0_u8; Self::LENGTH])) } /// Get the mask seed as a slice. @@ -35,17 +35,17 @@ impl ByteObject for MaskSeed { } impl MaskSeed { - /// Get the number of bytes of a mask seed. - pub const BYTES: usize = box_::SEEDBYTES; + /// Length in bytes of a [`MaskSeed`] + pub const LENGTH: usize = box_::SEEDBYTES; /// Generate a random mask seed. pub fn generate() -> Self { // safe unwrap: length of slice is guaranteed by constants - Self::from_slice_unchecked(randombytes(Self::BYTES).as_slice()) + Self::from_slice_unchecked(randombytes(Self::LENGTH).as_slice()) } /// Get the mask seed as an array. - pub fn as_array(&self) -> [u8; Self::BYTES] { + pub fn as_array(&self) -> [u8; Self::LENGTH] { (self.0).0 } @@ -56,13 +56,12 @@ impl MaskSeed { } /// Derive a mask of given length from the seed wrt the mask configuration. - pub fn derive_mask(&self, len: usize, config: &MaskConfig) -> Mask { + pub fn derive_mask(&self, len: usize, config: MaskConfig) -> Mask { let mut prng = ChaCha20Rng::from_seed(self.as_array()); - let integers = iter::repeat_with(|| generate_integer(&mut prng, config.order())) + let data = iter::repeat_with(|| generate_integer(&mut prng, &config.order())) .take(len) .collect(); - // safe unwrap: integer conformity is guaranteed by number generator - Mask::from_parts(integers, config.clone()).unwrap() + Mask::new(config, data) } } @@ -70,11 +69,17 @@ impl MaskSeed { /// An encrypted mask seed. pub struct EncryptedMaskSeed(Vec); +impl From> for EncryptedMaskSeed { + fn from(value: Vec) -> Self { + Self(value) + } +} + impl ByteObject for EncryptedMaskSeed { /// Create an encrypted mask seed from a slice of bytes. Fails if the length of the input is /// invalid. fn from_slice(bytes: &[u8]) -> Option { - if bytes.len() == Self::BYTES { + if bytes.len() == Self::LENGTH { Some(Self(bytes.to_vec())) } else { None @@ -83,7 +88,7 @@ impl ByteObject for EncryptedMaskSeed { /// Create an encrypted mask seed initialized to zero. fn zeroed() -> Self { - Self(vec![0_u8; Self::BYTES]) + Self(vec![0_u8; Self::LENGTH]) } /// Get the encrypted mask seed as a slice. @@ -94,7 +99,7 @@ impl ByteObject for EncryptedMaskSeed { impl EncryptedMaskSeed { /// Get the number of bytes of an encrypted mask seed. - pub const BYTES: usize = SEALBYTES + MaskSeed::BYTES; + pub const LENGTH: usize = SEALBYTES + MaskSeed::LENGTH; /// Decrypt an encrypted mask seed. Fails if the decryption fails. pub fn decrypt( @@ -116,17 +121,17 @@ mod tests { use super::*; use crate::{ crypto::generate_encrypt_key_pair, - mask::config::{BoundType, DataType, GroupType, MaskConfigs, ModelType}, + mask::config::{BoundType, DataType, GroupType, MaskConfig, ModelType}, }; #[test] fn test_constants() { - assert_eq!(MaskSeed::BYTES, 32); + assert_eq!(MaskSeed::LENGTH, 32); assert_eq!( MaskSeed::zeroed().as_slice(), [0_u8; 32].to_vec().as_slice(), ); - assert_eq!(EncryptedMaskSeed::BYTES, 80); + assert_eq!(EncryptedMaskSeed::LENGTH, 80); assert_eq!( EncryptedMaskSeed::zeroed().as_slice(), [0_u8; 80].to_vec().as_slice(), @@ -135,13 +140,12 @@ mod tests { #[test] fn test_derive_mask() { - let config = MaskConfigs::from_parts( - GroupType::Prime, - DataType::F32, - BoundType::B0, - ModelType::M3, - ) - .config(); + let config = MaskConfig { + group_type: GroupType::Prime, + data_type: DataType::F32, + bound_type: BoundType::B0, + model_type: ModelType::M3, + }; let seed = MaskSeed::generate(); let mask = seed.derive_mask(10, &config); assert_eq!(mask.integers().len(), 10); diff --git a/rust/src/message/mod.rs b/rust/src/message/mod.rs index 25b48c041..56b08aa5b 100644 --- a/rust/src/message/mod.rs +++ b/rust/src/message/mod.rs @@ -6,15 +6,8 @@ pub use self::traits::{FromBytes, LengthValueBuffer, ToBytes}; mod buffer; pub use self::buffer::*; -#[repr(u8)] -/// Message tags. -enum Tag { - #[allow(dead_code)] // None is used for tests - None, - Sum, - Update, - Sum2, -} +mod header; +pub use self::header::*; pub(crate) mod payload; pub use self::payload::*; diff --git a/rust/src/message/payload/update.rs b/rust/src/message/payload/update.rs index 87963fd26..dbf676076 100644 --- a/rust/src/message/payload/update.rs +++ b/rust/src/message/payload/update.rs @@ -301,7 +301,12 @@ impl FromBytes for UpdateOwned { #[cfg(test)] pub(crate) mod tests { use super::*; - use crate::{crypto::ByteObject, mask::MaskedModel, EncrMaskSeed, SumParticipantPublicKey}; + use crate::{ + crypto::ByteObject, + mask::MaskedModel, + EncryptedMaskSeed, + SumParticipantPublicKey, + }; use std::convert::TryFrom; fn sum_signature_bytes() -> Vec { @@ -326,11 +331,11 @@ pub(crate) mod tests { bytes.extend(vec![0x00, 0x00, 0x00, 0xe4]); bytes.extend(vec![0x55; SumParticipantPublicKey::LENGTH]); // sum participant pk - bytes.extend(vec![0x66; EncrMaskSeed::BYTES]); // ephemeral pk + bytes.extend(vec![0x66; EncryptedMaskSeed::BYTES]); // ephemeral pk // Second entry bytes.extend(vec![0x77; SumParticipantPublicKey::LENGTH]); // sum participant pk - bytes.extend(vec![0x88; EncrMaskSeed::BYTES]); // ephemeral pk + bytes.extend(vec![0x88; EncryptedMaskSeed::BYTES]); // ephemeral pk bytes } @@ -354,11 +359,11 @@ pub(crate) mod tests { local_seed_dict.insert( SumParticipantPublicKey::from_slice(vec![0x55; 32].as_slice()).unwrap(), - EncrMaskSeed::try_from(vec![0x66; EncrMaskSeed::BYTES]).unwrap(), + EncryptedMaskSeed::try_from(vec![0x66; EncryptedMaskSeed::BYTES]).unwrap(), ); local_seed_dict.insert( SumParticipantPublicKey::from_slice(vec![0x77; 32].as_slice()).unwrap(), - EncrMaskSeed::try_from(vec![0x88; EncrMaskSeed::BYTES]).unwrap(), + EncryptedMaskSeed::try_from(vec![0x88; EncryptedMaskSeed::BYTES]).unwrap(), ); UpdateOwned { diff --git a/rust/src/message/traits.rs b/rust/src/message/traits.rs index e5f314832..6fc5c3179 100644 --- a/rust/src/message/traits.rs +++ b/rust/src/message/traits.rs @@ -1,8 +1,8 @@ use super::DecodeError; -use crate::{crypto::ByteObject, EncrMaskSeed, LocalSeedDict, SumParticipantPublicKey}; +use crate::{crypto::ByteObject, mask::EncryptedMaskSeed, LocalSeedDict, SumParticipantPublicKey}; use anyhow::{anyhow, Context}; use std::{ - convert::{TryFrom, TryInto}, + convert::TryInto, io::{Cursor, Write}, ops::Range, }; @@ -286,10 +286,8 @@ macro_rules! impl_traits_for_length_value_types { } impl_traits_for_length_value_types!(crate::certificate::Certificate); -impl_traits_for_length_value_types!(crate::mask::Mask); -impl_traits_for_length_value_types!(crate::mask::MaskedModel); -const ENTRY_LENGTH: usize = SumParticipantPublicKey::LENGTH + EncrMaskSeed::BYTES; +const ENTRY_LENGTH: usize = SumParticipantPublicKey::LENGTH + EncryptedMaskSeed::LENGTH; impl ToBytes for LocalSeedDict { fn buffer_length(&self) -> usize { @@ -319,7 +317,7 @@ impl FromBytes for LocalSeedDict { // safe unwraps: lengths of slices are guaranteed // by constants. let key = SumParticipantPublicKey::from_slice(&chunk[..key_length]).unwrap(); - let value = EncrMaskSeed::try_from(&chunk[key_length..]).unwrap(); + let value = EncryptedMaskSeed::from_slice(&chunk[key_length..]).unwrap(); if dict.insert(key, value).is_some() { return Err(anyhow!("invalid local seed dictionary: duplicated key")); } diff --git a/rust/src/model.rs b/rust/src/model.rs deleted file mode 100644 index c7fd709f8..000000000 --- a/rust/src/model.rs +++ /dev/null @@ -1,511 +0,0 @@ -use std::convert::TryFrom; - -use num::{bigint::BigInt, clamp, rational::Ratio, traits::identities::Zero}; -use rand::SeedableRng; -use rand_chacha::ChaCha20Rng; - -use crate::{ - crypto::generate_integer, - mask::{config::MaskConfig, seed::MaskSeed, Integers, MaskedModel}, - PetError, -}; - -/// Masking of models. -pub trait MaskModels { - /// Get a reference to the weights. - fn weights(&self) -> &Vec; - - /// Cast the weights as ratios. Must handle non-finite weights. - fn as_ratios(&self) -> Vec>; - - /// Mask the model wrt the mask configuration. Enforces bounds on the scalar and weights. - /// - /// The masking proceeds in the following steps: - /// - clamp the scalar and the weights according to the mask configuration - /// - shift the weights into the non-negative reals - /// - shift the weights into the non-negative integers - /// - shift the weights into the finite group - /// - mask the weights with random elements from the finite group - /// - /// The random elements are derived from a seeded PRNG. Unmasking proceeds in reverse order. For - /// a more detailes see [the confluence page](https://xainag.atlassian.net/wiki/spaces/FP/pages/542408769/Masking). - fn mask(&self, scalar: f64, config: &MaskConfig) -> (MaskSeed, MaskedModel) { - let scalar = &Ratio::::from_float(clamp(scalar, 0_f64, 1_f64)).unwrap(); - let negative_bound = &-config.add_shift(); - let positive_bound = config.add_shift(); - let mask_seed = MaskSeed::generate(); - let mut prng = ChaCha20Rng::from_seed(mask_seed.as_array()); - let masked_weights = self - .as_ratios() - .iter() - .map(|weight| { - let scaled = scalar * clamp(weight, negative_bound, positive_bound); - let shifted = ((scaled + config.add_shift()) * config.exp_shift()) - .to_integer() - .to_biguint() - // safe unwrap: shifted weight is guaranteed to be non-negative - .unwrap(); - (shifted + generate_integer(&mut prng, config.order())) % config.order() - }) - .collect(); - // safe unwrap: masked weights are guaranteed to conform to the mask configuration - let masked_model = MaskedModel::from_parts(masked_weights, config.clone()).unwrap(); - (mask_seed, masked_model) - } -} - -#[derive(Clone, Debug, PartialEq, Eq)] -/// A model with weights represented as a vector of primitive numbers. -pub struct Model { - weights: Vec, -} - -impl TryFrom> for Model { - type Error = PetError; - - /// Create a model from its weights. Fails if the weights are not finite. - fn try_from(weights: Vec) -> Result { - if weights.iter().all(|weight| weight.is_finite()) { - Ok(Self { weights }) - } else { - Err(Self::Error::InvalidModel) - } - } -} - -impl TryFrom> for Model { - type Error = PetError; - - /// Create a model from its weights. Fails if the weights are not finite. - fn try_from(weights: Vec) -> Result { - if weights.iter().all(|weight| weight.is_finite()) { - Ok(Self { weights }) - } else { - Err(Self::Error::InvalidModel) - } - } -} - -impl From> for Model { - /// Create a model from its weights. - fn from(weights: Vec) -> Self { - Self { weights } - } -} - -impl From> for Model { - /// Create a model from its weights. - fn from(weights: Vec) -> Self { - Self { weights } - } -} - -impl MaskModels for Model { - /// Get a reference to the weights. - fn weights(&self) -> &Vec { - &self.weights - } - - /// Cast the weights as ratios. Positive/negative infinity is mapped to max/min and NaN to zero. - fn as_ratios(&self) -> Vec> { - self.weights - .iter() - .map(|weight| { - if weight.is_nan() { - Ratio::::zero() - } else { - // safe unwrap: clamped weight is guaranteed to be finite - Ratio::::from_float(clamp(*weight, f32::MIN, f32::MAX)).unwrap() - } - }) - .collect() - } -} - -impl MaskModels for Model { - /// Get a reference to the weights. - fn weights(&self) -> &Vec { - &self.weights - } - - /// Cast the weights as ratios. Positve/negative infinity is mapped to max/min and NaN to zero. - fn as_ratios(&self) -> Vec> { - self.weights - .iter() - .map(|weight| { - if weight.is_nan() { - Ratio::::zero() - } else { - // safe unwrap: clamped weight is guaranteed to be finite - Ratio::::from_float(clamp(*weight, f64::MIN, f64::MAX)).unwrap() - } - }) - .collect() - } -} - -impl MaskModels for Model { - /// Get a reference to the weights. - fn weights(&self) -> &Vec { - &self.weights - } - - /// Cast the weights as ratios. - fn as_ratios(&self) -> Vec> { - self.weights - .iter() - .map(|weight| Ratio::from_integer(BigInt::from(*weight))) - .collect() - } -} - -impl MaskModels for Model { - /// Get a reference to the weights. - fn weights(&self) -> &Vec { - &self.weights - } - - /// Cast the weights as ratios. - fn as_ratios(&self) -> Vec> { - self.weights - .iter() - .map(|weight| Ratio::from_integer(BigInt::from(*weight))) - .collect() - } -} - -#[cfg(test)] -mod tests { - use std::iter; - - use rand::distributions::{Distribution, Uniform}; - - use super::*; - use crate::mask::{ - config::{ - BoundType::{Bmax, B0, B2, B4, B6}, - DataType::{F32, F64, I32, I64}, - GroupType::{Integer, Power2, Prime}, - MaskConfigs, - ModelType::{M12, M3, M6, M9}, - }, - seed::MaskSeed, - MaskIntegers, - }; - - #[test] - fn test_model_f32() { - let weights = vec![-1_f32, 0_f32, 1_f32]; - let model = Model::::try_from(weights.clone()).unwrap(); - assert_eq!(model.weights(), &weights); - assert_eq!( - model.as_ratios(), - vec![ - Ratio::::from_float(-1_f32).unwrap(), - Ratio::::zero(), - Ratio::::from_float(1_f32).unwrap(), - ], - ); - } - - #[test] - fn test_model_f32_inf() { - let weights = vec![-1_f32, 0_f32, f32::INFINITY]; - assert_eq!( - Model::::try_from(weights.clone()).unwrap_err(), - PetError::InvalidModel, - ); - assert_eq!( - Model:: { weights }.as_ratios(), - vec![ - Ratio::::from_float(-1_f32).unwrap(), - Ratio::::zero(), - Ratio::::from_float(f32::MAX).unwrap(), - ], - ); - } - - #[test] - fn test_model_f32_neginf() { - let weights = vec![f32::NEG_INFINITY, 0_f32, 1_f32]; - assert_eq!( - Model::::try_from(weights.clone()).unwrap_err(), - PetError::InvalidModel, - ); - assert_eq!( - Model:: { weights }.as_ratios(), - vec![ - Ratio::::from_float(f32::MIN).unwrap(), - Ratio::::zero(), - Ratio::::from_float(1_f32).unwrap(), - ], - ); - } - - #[test] - fn test_model_f32_nan() { - let weights = vec![-1_f32, f32::NAN, 1_f32]; - assert_eq!( - Model::::try_from(weights.clone()).unwrap_err(), - PetError::InvalidModel, - ); - assert_eq!( - Model:: { weights }.as_ratios(), - vec![ - Ratio::::from_float(-1_f32).unwrap(), - Ratio::::zero(), - Ratio::::from_float(1_f32).unwrap(), - ], - ); - } - - #[test] - fn test_model_f64() { - let weights = vec![-1_f64, 0_f64, 1_f64]; - let model = Model::::try_from(weights.clone()).unwrap(); - assert_eq!(model.weights(), &weights); - assert_eq!( - model.as_ratios(), - vec![ - Ratio::::from_float(-1_f64).unwrap(), - Ratio::::zero(), - Ratio::::from_float(1_f64).unwrap(), - ], - ); - } - - #[test] - fn test_model_f64_inf() { - let weights = vec![-1_f64, 0_f64, f64::INFINITY]; - assert_eq!( - Model::::try_from(weights.clone()).unwrap_err(), - PetError::InvalidModel, - ); - assert_eq!( - Model:: { weights }.as_ratios(), - vec![ - Ratio::::from_float(-1_f64).unwrap(), - Ratio::::zero(), - Ratio::::from_float(f64::MAX).unwrap(), - ], - ); - } - - #[test] - fn test_model_f64_neginf() { - let weights = vec![f64::NEG_INFINITY, 0_f64, 1_f64]; - assert_eq!( - Model::::try_from(weights.clone()).unwrap_err(), - PetError::InvalidModel, - ); - assert_eq!( - Model:: { weights }.as_ratios(), - vec![ - Ratio::::from_float(f64::MIN).unwrap(), - Ratio::::zero(), - Ratio::::from_float(1_f64).unwrap(), - ], - ); - } - - #[test] - fn test_model_f64_nan() { - let weights = vec![-1_f64, f64::NAN, 1_f64]; - assert_eq!( - Model::::try_from(weights.clone()).unwrap_err(), - PetError::InvalidModel, - ); - assert_eq!( - Model:: { weights }.as_ratios(), - vec![ - Ratio::::from_float(-1_f64).unwrap(), - Ratio::::zero(), - Ratio::::from_float(1_f64).unwrap(), - ], - ); - } - - #[test] - fn test_model_i32() { - let weights = vec![-1_i32, 0_i32, 1_i32]; - let model = Model::::from(weights.clone()); - assert_eq!(model.weights(), &weights); - assert_eq!( - model.as_ratios(), - vec![ - Ratio::from_integer(BigInt::from(-1_i32)), - Ratio::::zero(), - Ratio::from_integer(BigInt::from(1_i32)), - ], - ); - } - - #[test] - fn test_model_i64() { - let weights = vec![-1_i64, 0_i64, 1_i64]; - let model = Model::::from(weights.clone()); - assert_eq!(model.weights(), &weights); - assert_eq!( - model.as_ratios(), - vec![ - Ratio::from_integer(BigInt::from(-1_i64)), - Ratio::::zero(), - Ratio::from_integer(BigInt::from(1_i64)), - ], - ); - } - - /// Generate tests for masking and unmasking. The tests proceed in the following steps: - /// - generate random weights from a uniform distribution with a seeded PRNG - /// - create a model from the weights and mask it - /// - check that all masked weights belong to the chosen finite group - /// - unmask the masked model - /// - check that all unmasked weights are equal to the original weights (up to a tolerance) - /// - /// The arguments to the macro are: - /// - a suffix for the test name - /// - the group type of the model (variants of `GroupType`) - /// - the data type of the model (either primitives or variants of `DataType`) - /// - an absolute bound for the weights (optional, choices: 1, 100, 10_000, 1_000_000) - /// - the number of weights - /// - a tolerance for the equality check - /// - /// For float data types the error depends on the order of magnitude of the weights, therefore - /// it may be necessary to raise the tolerance or bound the random weights if this test fails. - macro_rules! test_masking { - ($suffix:ident, $group:ty, $data:ty, $bound:expr, $len:expr, $tol:expr $(,)?) => { - paste::item! { - #[test] - fn []() { - paste::expr! { - let mut prng = ChaCha20Rng::from_seed([0_u8; 32]); - let uniform = if $bound == 0 { - Uniform::new([<$data:lower>]::MIN, [<$data:lower>]::MAX) - } else { - Uniform::new(-$bound as [<$data:lower>], $bound as [<$data:lower>]) - }; - let weights = iter::repeat_with(|| uniform.sample(&mut prng)) - .take($len as usize) - .collect::>(); - let model = Model::try_from(weights).unwrap(); - let bound_type = match $bound { - 1 => B0, - 100 => B2, - 10_000 => B4, - 1_000_000 => B6, - 0 => Bmax, - _ => panic!("Unknown bound!") - }; - let config = MaskConfigs::from_parts( - $group, - [<$data:upper>], - bound_type, - M3 - ).config(); - let (mask_seed, masked_model) = model.mask(1_f64, &config); - assert_eq!( - masked_model.integers().len(), - $len as usize - ); - assert!( - masked_model - .integers() - .iter() - .all(|integer| integer < config.order()) - ); - let mask = mask_seed.derive_mask($len as usize, &config); - let unmasked_model: Model<[<$data:lower>]> = masked_model - .unmask(&mask, 1_usize) - .unwrap(); - assert!( - model - .weights() - .iter() - .zip(unmasked_model.weights().iter()) - .all( - |(weight, unmasked_weight)| - (weight - unmasked_weight).abs() - <= $tol as [<$data:lower>] - ) - ); - } - } - } - }; - ($suffix:ident, $group:ty, $data:ty, $len:expr, $tol:expr $(,)?) => { - test_masking!($suffix, $group, $data, 0, $len, $tol); - }; - } - - test_masking!(int_f32_b0, Integer, f32, 1, 10, 1e-3); - test_masking!(int_f32_b2, Integer, f32, 100, 10, 1e-3); - test_masking!(int_f32_b4, Integer, f32, 10_000, 10, 1e-3); - test_masking!(int_f32_b6, Integer, f32, 1_000_000, 10, 1e-3); - test_masking!(int_f32_bmax, Integer, f32, 10, 1e-3); - - test_masking!(prime_f32_b0, Prime, f32, 1, 10, 1e-3); - test_masking!(prime_f32_b2, Prime, f32, 100, 10, 1e-3); - test_masking!(prime_f32_b4, Prime, f32, 10_000, 10, 1e-3); - test_masking!(prime_f32_b6, Prime, f32, 1_000_000, 10, 1e-3); - test_masking!(prime_f32_bmax, Prime, f32, 10, 1e-3); - - test_masking!(pow_f32_b0, Power2, f32, 1, 10, 1e-3); - test_masking!(pow_f32_b2, Power2, f32, 100, 10, 1e-3); - test_masking!(pow_f32_b4, Power2, f32, 10_000, 10, 1e-3); - test_masking!(pow_f32_b6, Power2, f32, 1_000_000, 10, 1e-3); - test_masking!(pow_f32_bmax, Power2, f32, 10, 1e-3); - - test_masking!(int_f64_b0, Integer, f64, 1, 10, 1e-3); - test_masking!(int_f64_b2, Integer, f64, 100, 10, 1e-3); - test_masking!(int_f64_b4, Integer, f64, 10_000, 10, 1e-3); - test_masking!(int_f64_b6, Integer, f64, 1_000_000, 10, 1e-3); - test_masking!(int_f64_bmax, Integer, f64, 10, 1e-3); - - test_masking!(prime_f64_b0, Prime, f64, 1, 10, 1e-3); - test_masking!(prime_f64_b2, Prime, f64, 100, 10, 1e-3); - test_masking!(prime_f64_b4, Prime, f64, 10_000, 10, 1e-3); - test_masking!(prime_f64_b6, Prime, f64, 1_000_000, 10, 1e-3); - test_masking!(prime_f64_bmax, Prime, f64, 10, 1e-3); - - test_masking!(pow_f64_b0, Power2, f64, 1, 10, 1e-3); - test_masking!(pow_f64_b2, Power2, f64, 100, 10, 1e-3); - test_masking!(pow_f64_b4, Power2, f64, 10_000, 10, 1e-3); - test_masking!(pow_f64_b6, Power2, f64, 1_000_000, 10, 1e-3); - test_masking!(pow_f64_bmax, Power2, f64, 10, 1e-3); - - test_masking!(int_i32_b0, Integer, i32, 1, 10, 1e-3); - test_masking!(int_i32_b2, Integer, i32, 100, 10, 1e-3); - test_masking!(int_i32_b4, Integer, i32, 10_000, 10, 1e-3); - test_masking!(int_i32_b6, Integer, i32, 1_000_000, 10, 1e-3); - test_masking!(int_i32_bmax, Integer, i32, 10, 1e-3); - - test_masking!(prime_i32_b0, Prime, i32, 1, 10, 1e-3); - test_masking!(prime_i32_b2, Prime, i32, 100, 10, 1e-3); - test_masking!(prime_i32_b4, Prime, i32, 10_000, 10, 1e-3); - test_masking!(prime_i32_b6, Prime, i32, 1_000_000, 10, 1e-3); - test_masking!(prime_i32_bmax, Prime, i32, 10, 1e-3); - - test_masking!(pow_i32_b0, Power2, i32, 1, 10, 1e-3); - test_masking!(pow_i32_b2, Power2, i32, 100, 10, 1e-3); - test_masking!(pow_i32_b4, Power2, i32, 10_000, 10, 1e-3); - test_masking!(pow_i32_b6, Power2, i32, 1_000_000, 10, 1e-3); - test_masking!(pow_i32_bmax, Power2, i32, 10, 1e-3); - - test_masking!(int_i64_b0, Integer, i64, 1, 10, 1e-3); - test_masking!(int_i64_b2, Integer, i64, 100, 10, 1e-3); - test_masking!(int_i64_b4, Integer, i64, 10_000, 10, 1e-3); - test_masking!(int_i64_b6, Integer, i64, 1_000_000, 10, 1e-3); - test_masking!(int_i64_bmax, Integer, i64, 10, 1e-3); - - test_masking!(prime_i64_b0, Prime, i64, 1, 10, 1e-3); - test_masking!(prime_i64_b2, Prime, i64, 100, 10, 1e-3); - test_masking!(prime_i64_b4, Prime, i64, 10_000, 10, 1e-3); - test_masking!(prime_i64_b6, Prime, i64, 1_000_000, 10, 1e-3); - test_masking!(prime_i64_bmax, Prime, i64, 10, 1e-3); - - test_masking!(pow_i64_b0, Power2, i64, 1, 10, 1e-3); - test_masking!(pow_i64_b2, Power2, i64, 100, 10, 1e-3); - test_masking!(pow_i64_b4, Power2, i64, 10_000, 10, 1e-3); - test_masking!(pow_i64_b6, Power2, i64, 1_000_000, 10, 1e-3); - test_masking!(pow_i64_bmax, Power2, i64, 10, 1e-3); -} diff --git a/rust/src/participant.rs b/rust/src/participant.rs index 99698a596..90fd4bfcf 100644 --- a/rust/src/participant.rs +++ b/rust/src/participant.rs @@ -1,16 +1,25 @@ -use std::{convert::TryFrom, default::Default}; +use std::default::Default; + +use sodiumoxide; use crate::{ certificate::Certificate, crypto::{generate_encrypt_key_pair, generate_signing_key_pair, ByteObject}, mask::{ - config::{BoundType, DataType, GroupType, MaskConfig, MaskConfigs, ModelType}, - seed::MaskSeed, - Integers, + BoundType, + DataType, + FromPrimitives, + GroupType, Mask, + MaskConfig, + MaskSeed, + MaskedModel, + Masker, + Model, + ModelType, }, - message::{sum::SumMessage, sum2::Sum2Message, update::UpdateMessage}, - model::{MaskModels, Model}, + message::{MessageOwned, MessageSeal, Sum2Owned, SumOwned, UpdateOwned}, + utils::is_eligible, CoordinatorPublicKey, InitError, LocalSeedDict, @@ -24,7 +33,7 @@ use crate::{ SumParticipantEphemeralSecretKey, }; -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq)] /// Tasks of a participant. enum Task { Sum, @@ -53,7 +62,7 @@ impl Default for Participant { let sk = ParticipantSecretKey::zeroed(); let ephm_pk = SumParticipantEphemeralPublicKey::zeroed(); let ephm_sk = SumParticipantEphemeralSecretKey::zeroed(); - let certificate = Certificate::zeroed(); + let certificate = Certificate::new(); let sum_signature = ParticipantTaskSignature::zeroed(); let update_signature = ParticipantTaskSignature::zeroed(); let task = Task::None; @@ -91,9 +100,9 @@ impl Participant { /// Check eligibility for a task. pub fn check_task(&mut self, round_sum: f64, round_update: f64) { - if self.sum_signature.is_eligible(round_sum) { + if is_eligible(&self.sum_signature, round_sum) { self.task = Task::Sum; - } else if self.update_signature.is_eligible(round_update) { + } else if is_eligible(&self.update_signature, round_update) { self.task = Task::Update; } else { self.task = Task::None; @@ -114,18 +123,8 @@ impl Participant { } /// Compose an update message. - pub fn compose_update_message(&self, pk: &CoordinatorPublicKey, sum_dict: &SumDict) -> Vec { - let model = Model::try_from(vec![0_f32, 0.5, -0.5]).unwrap(); // dummy - let scalar = 0.5_f64; // dummy - let mask_config = MaskConfigs::from_parts( - GroupType::Prime, - DataType::F32, - BoundType::B0, - ModelType::M3, - ) - .config(); - // safe unwrap: data types of model and mask configuration conform due to definition above - let (mask_seed, masked_model) = model.mask(scalar, &mask_config); + pub fn compose_update_message(&self, pk: CoordinatorPublicKey, sum_dict: &SumDict) -> Vec { + let (mask_seed, masked_model) = Self::mask_model(); let local_seed_dict = Self::create_local_seed_dict(sum_dict, &mask_seed); let payload = UpdateOwned { @@ -146,19 +145,22 @@ impl Participant { seed_dict: &SeedDict, ) -> Result, PetError> { let mask_seeds = self.get_seeds(seed_dict)?; - let mask_len = 3; // dummy - let mask_config = MaskConfigs::from_parts( - GroupType::Prime, - DataType::F32, - BoundType::B0, - ModelType::M3, - ) - .config(); - let mask = self.compute_global_mask(mask_seeds, mask_len, &mask_config)?; - Ok( - Sum2Message::from_parts(&self.pk, &self.sum_signature, &self.certificate, &mask) - .seal(&self.sk, pk), - ) + let mask = self.compute_global_mask(mask_seeds); + let payload = Sum2Owned { + mask: mask, + sum_signature: self.sum_signature, + }; + + let message = MessageOwned::new_sum2(pk, self.pk, payload); + Ok(self.seal_message(&pk, &message)) + } + + fn seal_message(&self, pk: &CoordinatorPublicKey, message: &MessageOwned) -> Vec { + let message_seal = MessageSeal { + recipient_pk: pk, + sender_sk: &self.sk, + }; + message_seal.seal(message) } /// Generate an ephemeral encryption key pair. @@ -168,6 +170,14 @@ impl Participant { self.ephm_sk = ephm_sk; } + /// Generate a mask seed and mask a local model (dummy). + fn mask_model() -> (MaskSeed, MaskedModel) { + let model = Model::from_primitives(vec![0_i32, 1_i32, 2_i32, 3_i32].into_iter()).unwrap(); + let masker = Masker::new(dummy_config()); + let (seed, masked_model) = masker.mask_model(0.5, model); + (seed, masked_model) + } + // Create a local seed dictionary from a sum dictionary. fn create_local_seed_dict(sum_dict: &SumDict, mask_seed: &MaskSeed) -> LocalSeedDict { sum_dict @@ -186,23 +196,9 @@ impl Participant { .collect() } - /// Compute a global mask from local mask seeds. - fn compute_global_mask( - &self, - mask_seeds: Vec, - mask_len: usize, - mask_config: &MaskConfig, - ) -> Result { - if !mask_seeds.is_empty() { - let mut global_mask = mask_seeds[0].derive_mask(mask_len, mask_config); - for mask_seed in mask_seeds[1..].iter() { - global_mask = - global_mask.aggregate(&mask_seed.derive_mask(mask_len, mask_config))?; - } - Ok(global_mask) - } else { - Err(PetError::InvalidMask) - } + /// Compute a global mask from local mask seeds (dummy). + fn compute_global_mask(&self, _mask_seeds: Vec) -> Mask { + Mask::new_empty(dummy_config()) } } @@ -213,7 +209,7 @@ mod tests { iter, }; - use sodiumoxide::randombytes::{randombytes, randombytes_uniform}; + use sodiumoxide::randombytes::randombytes_uniform; use super::*; use crate::{crypto::Signature, SumParticipantPublicKey, UpdateParticipantPublicKey}; @@ -225,7 +221,7 @@ mod tests { assert_eq!(part.sk.as_slice().len(), 64); assert_eq!(part.ephm_pk, SumParticipantEphemeralPublicKey::zeroed()); assert_eq!(part.ephm_sk, SumParticipantEphemeralSecretKey::zeroed()); - assert_eq!(part.certificate, Certificate::zeroed()); + assert_eq!(part.certificate, Certificate::new()); assert_eq!(part.sum_signature, ParticipantTaskSignature::zeroed()); assert_eq!(part.update_signature, ParticipantTaskSignature::zeroed()); assert_eq!(part.task, Task::None); @@ -249,29 +245,29 @@ mod tests { #[test] fn test_check_task() { let mut part = Participant::new().unwrap(); - let eligible_signature = Signature::from_slice_unchecked(&[ - 172, 29, 85, 219, 118, 44, 107, 32, 219, 253, 25, 242, 53, 45, 111, 62, 102, 130, 24, - 8, 222, 199, 34, 120, 166, 163, 223, 229, 100, 50, 252, 244, 250, 88, 196, 151, 136, - 48, 39, 198, 166, 86, 29, 151, 13, 81, 69, 198, 40, 148, 134, 126, 7, 202, 1, 56, 174, - 43, 89, 28, 242, 194, 4, 214, + let elligible_signature = Signature::from_slice_unchecked(&[ + 229, 191, 74, 163, 113, 6, 242, 191, 255, 225, 40, 89, 210, 94, 25, 50, 44, 129, 155, + 241, 99, 64, 25, 212, 157, 235, 102, 95, 115, 18, 158, 115, 253, 136, 178, 223, 4, 47, + 54, 162, 236, 78, 126, 114, 205, 217, 250, 163, 223, 149, 31, 65, 179, 179, 60, 64, 34, + 1, 78, 245, 1, 50, 165, 47, ]); - let ineligible_signature = Signature::from_slice_unchecked(&[ - 119, 2, 197, 174, 52, 165, 229, 22, 218, 210, 240, 188, 220, 232, 149, 129, 211, 13, - 61, 217, 186, 79, 102, 15, 109, 237, 83, 193, 12, 117, 210, 66, 99, 230, 30, 131, 63, - 108, 28, 222, 48, 92, 153, 71, 159, 220, 115, 181, 183, 155, 146, 182, 205, 89, 140, - 234, 100, 40, 199, 248, 23, 147, 172, 248, + let inelligible_signature = Signature::from_slice_unchecked(&[ + 15, 107, 81, 84, 105, 246, 165, 81, 76, 125, 140, 172, 113, 85, 51, 173, 119, 123, 78, + 114, 249, 182, 135, 212, 134, 38, 125, 153, 120, 45, 179, 55, 116, 155, 205, 51, 247, + 37, 78, 147, 63, 231, 28, 61, 251, 41, 48, 239, 125, 0, 129, 126, 194, 123, 183, 11, + 215, 220, 1, 225, 248, 131, 64, 242, ]); - part.sum_signature = eligible_signature; - part.update_signature = ineligible_signature; + part.sum_signature = elligible_signature; + part.update_signature = inelligible_signature; part.check_task(0.5_f64, 0.5_f64); assert_eq!(part.task, Task::Sum); - part.update_signature = eligible_signature; + part.update_signature = elligible_signature; part.check_task(0.5_f64, 0.5_f64); assert_eq!(part.task, Task::Sum); - part.sum_signature = ineligible_signature; + part.sum_signature = inelligible_signature; part.check_task(0.5_f64, 0.5_f64); assert_eq!(part.task, Task::Update); - part.update_signature = ineligible_signature; + part.update_signature = inelligible_signature; part.check_task(0.5_f64, 0.5_f64); assert_eq!(part.task, Task::None); } @@ -286,7 +282,7 @@ mod tests { #[test] fn test_create_local_seed_dict() { - let mask_seed = MaskSeed::generate(); + let (mask_seed, _) = Participant::mask_model(); let ephm_dict = iter::repeat_with(|| generate_encrypt_key_pair()) .take(1 + randombytes_uniform(10) as usize) .collect::>( @@ -295,7 +291,7 @@ mod tests { .iter() .map(|(ephm_pk, _)| { ( - SumParticipantPublicKey::from_slice_unchecked(&randombytes(32)), + SumParticipantPublicKey::from_slice(&randombytes(32)).unwrap(), *ephm_pk, ) }) @@ -306,7 +302,7 @@ mod tests { assert!(seed_dict.iter().all(|(pk, seed)| { let ephm_pk = sum_dict.get(pk).unwrap(); let ephm_sk = ephm_dict.get(ephm_pk).unwrap(); - mask_seed == seed.decrypt(ephm_pk, ephm_sk).unwrap() + mask_seed == seed.open(ephm_pk, ephm_sk).unwrap() })); } @@ -314,7 +310,7 @@ mod tests { fn test_get_seeds() { let mut part = Participant::new().unwrap(); part.gen_ephm_keypair(); - let mask_seeds = iter::repeat_with(|| MaskSeed::generate()) + let mask_seeds = iter::repeat_with(|| MaskSeed::new()) .take(1 + randombytes_uniform(10) as usize) .collect::>(); let seed_dict = [( @@ -323,8 +319,8 @@ mod tests { .iter() .map(|seed| { ( - UpdateParticipantPublicKey::from_slice_unchecked(&randombytes(32)), - seed.encrypt(&part.ephm_pk), + UpdateParticipantPublicKey::from_slice(&randombytes(32)).unwrap(), + seed.seal(&part.ephm_pk), ) }) .collect(), @@ -336,12 +332,8 @@ mod tests { part.get_seeds(&seed_dict) .unwrap() .into_iter() - .map(|seed| seed.as_array()) - .collect::>(), - mask_seeds - .into_iter() - .map(|seed| seed.as_array()) .collect::>(), + mask_seeds.into_iter().collect::>(), ); assert_eq!( part.get_seeds(&SeedDict::new()).unwrap_err(), @@ -349,3 +341,12 @@ mod tests { ); } } + +fn dummy_config() -> MaskConfig { + MaskConfig { + group_type: GroupType::Integer, + data_type: DataType::I32, + bound_type: BoundType::Bmax, + model_type: ModelType::M3, + } +} diff --git a/rust/src/service/data.rs b/rust/src/service/data.rs index 31e615374..5787ab3c3 100644 --- a/rust/src/service/data.rs +++ b/rust/src/service/data.rs @@ -1,10 +1,8 @@ use thiserror::Error; use crate::{ - coordinator::{Coordinator, ProtocolEvent, RoundParameters}, - model::Model, + coordinator::{ProtocolEvent, RoundParameters}, service::handle::{SerializedSeedDict, SerializedSumDict}, - MaskHash, SeedDict, SumParticipantPublicKey, }; @@ -38,6 +36,10 @@ pub enum PhaseData { /// Data held by the service during the sum2 phase #[from] Sum2(Sum2Data), + + /// Data held by the service during the aggregation phase + #[from] + Aggregation, } impl PhaseData { diff --git a/rust/src/service/mod.rs b/rust/src/service/mod.rs index c475c8d2f..579220962 100644 --- a/rust/src/service/mod.rs +++ b/rust/src/service/mod.rs @@ -1,8 +1,4 @@ -use crate::{ - coordinator::{Coordinator, Coordinators, MaskCoordinators, RoundParameters}, - InitError, -}; -use derive_more::From; +use crate::{coordinator::Coordinator, InitError}; use futures::ready; use std::{ future::Future, @@ -31,7 +27,7 @@ pub use handle::{ pub struct Service { /// The coordinator holds the protocol state: crypto material, sum /// and update dictionaries, configuration, etc. - coordinator: Coordinator, // todo: implement a choice for data types + coordinator: Coordinator, /// Events to handle events: EventStream, @@ -47,7 +43,7 @@ impl Service { let (handle, events) = Handle::new(); let service = Self { events, - coordinator: Coordinator::::new()?, + coordinator: Coordinator::new()?, data: Data::new(), }; Ok((service, handle)) diff --git a/rust/src/utils.rs b/rust/src/utils.rs index 8b1378917..19d1b70b7 100644 --- a/rust/src/utils.rs +++ b/rust/src/utils.rs @@ -1 +1,56 @@ +use crate::crypto::ByteObject; +use num::{ + bigint::{BigUint, ToBigInt}, + rational::Ratio, +}; +use sodiumoxide::crypto::hash::sha256; +use crate::ParticipantTaskSignature; + +/// Compute the floating point representation of the hashed signature and ensure that it +/// is below the given threshold: int(hash(signature)) / (2**hashbits - 1) <= threshold. +pub fn is_eligible(signature: &ParticipantTaskSignature, threshold: f64) -> bool { + if threshold < 0_f64 { + false + } else if threshold > 1_f64 { + true + } else { + Ratio::new( + BigUint::from_bytes_be(&sha256::hash(signature.as_slice()).0[..]) + .to_bigint() + // FIXME: can we unwrap here? + .unwrap(), + BigUint::from_bytes_be(&[255_u8; 32][..]) + .to_bigint() + .unwrap(), + ) <= Ratio::from_float(threshold).unwrap() + } +} + +#[cfg(test)] +mod tests { + use crate::crypto::Signature; + + use super::*; + + #[test] + fn test_is_eligible() { + // eligible signature + let sig = Signature::from_slice_unchecked(&[ + 229, 191, 74, 163, 113, 6, 242, 191, 255, 225, 40, 89, 210, 94, 25, 50, 44, 129, 155, + 241, 99, 64, 25, 212, 157, 235, 102, 95, 115, 18, 158, 115, 253, 136, 178, 223, 4, 47, + 54, 162, 236, 78, 126, 114, 205, 217, 250, 163, 223, 149, 31, 65, 179, 179, 60, 64, 34, + 1, 78, 245, 1, 50, 165, 47, + ]); + assert!(is_eligible(&sig, 0.5_f64)); + + // ineligible signature + let sig = Signature::from_slice_unchecked(&[ + 15, 107, 81, 84, 105, 246, 165, 81, 76, 125, 140, 172, 113, 85, 51, 173, 119, 123, 78, + 114, 249, 182, 135, 212, 134, 38, 125, 153, 120, 45, 179, 55, 116, 155, 205, 51, 247, + 37, 78, 147, 63, 231, 28, 61, 251, 41, 48, 239, 125, 0, 129, 126, 194, 123, 183, 11, + 215, 220, 1, 225, 248, 131, 64, 242, + ]); + assert!(!is_eligible(&sig, 0.5_f64)); + } +}