From e7ea51dcb152409cfe8dbc491ff5c7da9c8ff2fb Mon Sep 17 00:00:00 2001 From: little-dude Date: Thu, 30 Apr 2020 16:54:49 +0200 Subject: [PATCH] refactor messages Summary: ======== 1. add a `Header` type for the common fields 2. add a `LengthValueBuffer` type to handle the variable length fields 3. decouple the crypto and parsing parts 4. add a `Message` type that wraps sum, update and sum2 messages 5. have an `Owned` and a `Borrowed` version of every type 7. small bug fixes & improvements: - not using `usize` as a field since it's platform dependent - detection of truncated local seed dictionary - `XxxBuffer` does not allocate a `Vec` - remove impls for specific types (`impl TryFrom> for XxxBuffer`, `impl XxxBuffer>` etc) - make the `XxxBuffer()` APIs more consistent, opening the door to code generation via macros - use a custom `DecodeError` type that contains the whole stack of errors, which makes debugging easier Details: ======== 1. Introduce a `Header` type for the common fields -------------------------------------------------- All the messages share common fields. In networking protocol, these common fields are usually handled separately from the rest of the message. They are usually called headers, and the rest of the message is the payload. We defined the following header: ```rust /// A header, common to all the message pub struct Header { /// Type of message pub tag: Tag, /// Coordinator public key pub coordinator_pk: CK, /// Participant public key pub participant_pk: PK, /// A certificate that identifies the author of the message pub certificate: Option, } ``` Currently the code to handle these common fields lives in the `MessageBuffer` trait which is implemented for each message type, thus reducing boilerplate for parsing these fields. I can see several downsides to using a trait for this though: a. the `serialize()` and `deserialize()` methods must call the `MessageBuffer` methods for each of these common fields. b. it is difficult to handle variable length fields in the header (which maybe is why the optional certificate is handled by each message separately?) c. tests require full messages `a.` is not too much of a problem, and I'm not sure to what extent `b.` holds true. But removing the header logic does lead to much simpler tests. 2. Add a `LengthValueBuffer` type to handle the variable length fields ---------------------------------------------------------------------- This reduces error prone boilerplate, like we had in `update.rs`: ```rust if buffer.len() >= Self::LOCAL_SEED_DICT_LEN_RANGE.end { buffer.certificate_range = Self::LOCAL_SEED_DICT_LEN_RANGE.end ..Self::LOCAL_SEED_DICT_LEN_RANGE.end + usize::from_le_bytes(buffer.certificate_len().try_into().unwrap()); buffer.masked_model_range = buffer.certificate_range.end ..buffer.certificate_range.end + usize::from_le_bytes(buffer.masked_model_len().try_into().unwrap()); buffer.local_seed_dict_range = buffer.masked_model_range.end ..buffer.masked_model_range.end + usize::from_le_bytes(buffer.local_seed_dict_len().try_into().unwrap()); } ``` This also allows us to automate the serialization/deserialization of length-variable types like masks, certificates and masked models (see `impl_traits_for_length_value_types!`) 3. decouple the crypto and parsing parts ---------------------------------------- The message signature and encryption is the very last/first step for every message, and is always exactly the same, so it makes sense to keep it separate. Therefore we removed the `open()` and `seal()` methods for the messages themself and moved the logic to two dedicated types: `MessageSeal` for signing and encrypting messages, and `MessageOpener` for decrypting and verifying message signatures. With this, we could now move the crypto logic to the _transport_ layer, and only handle fully fledged messages in the business logic layer. 4. add a `Message` type that wraps sum, update and sum2 messages ---------------------------------------------------------------- ```rust pub struct Message { pub header: Header, pub payload: Payload, } pub enum Payload { Sum(SumMessage), Sum2(Sum2Message), Update(UpdateMessage), } ``` This is actually needed if we want to move message serialization/deserialization to Tokio. By doing so, we don't have to repeat the same code for the fields that are common to all the messages. 5. have an `Owned` and a `Borrowed` version of every type --------------------------------------------------------- Each type comes in two flavours: `Owned` and `Borrowed`. An `XxxOwned` type _owns_ their fields, while an `XxxBorrowed` type may only have references to some of the fields. The `Borrowed` variant is needed because we have some potentially large fields that we don't want to clone when emitting a message. For instance, it would be wasteful for an update participant sending an update message with a large local seed dictionary to clone that dictionary. --- rust/rustfmt.toml | 1 + rust/src/coordinator.rs | 24 +- rust/src/crypto/encrypt.rs | 17 +- rust/src/crypto/hash.rs | 45 ++ rust/src/crypto/mod.rs | 1 + rust/src/crypto/sign.rs | 12 + rust/src/lib.rs | 4 +- rust/src/message/buffer.rs | 367 ++++++++++++++ rust/src/message/header.rs | 116 +++++ rust/src/message/message.rs | 182 +++++++ rust/src/message/mod.rs | 134 +---- rust/src/message/payload/mod.rs | 52 ++ rust/src/message/payload/sum.rs | 287 +++++++++++ rust/src/message/payload/sum2.rs | 304 ++++++++++++ rust/src/message/payload/update.rs | 433 ++++++++++++++++ rust/src/message/sum.rs | 460 ----------------- rust/src/message/sum2.rs | 489 ------------------- rust/src/message/traits.rs | 405 +++++++++++++++ rust/src/message/update.rs | 759 ----------------------------- rust/src/message/utils.rs | 5 + rust/src/participant.rs | 36 +- rust/src/service/data.rs | 3 - rust/src/service/handle.rs | 1 - rust/src/service/mod.rs | 12 +- 24 files changed, 2273 insertions(+), 1876 deletions(-) create mode 100644 rust/src/crypto/hash.rs create mode 100644 rust/src/message/buffer.rs create mode 100644 rust/src/message/header.rs create mode 100644 rust/src/message/message.rs create mode 100644 rust/src/message/payload/mod.rs create mode 100644 rust/src/message/payload/sum.rs create mode 100644 rust/src/message/payload/sum2.rs create mode 100644 rust/src/message/payload/update.rs delete mode 100644 rust/src/message/sum.rs delete mode 100644 rust/src/message/sum2.rs create mode 100644 rust/src/message/traits.rs delete mode 100644 rust/src/message/update.rs create mode 100644 rust/src/message/utils.rs diff --git a/rust/rustfmt.toml b/rust/rustfmt.toml index b2750ab72..d35e3953e 100644 --- a/rust/rustfmt.toml +++ b/rust/rustfmt.toml @@ -1,2 +1,3 @@ imports_layout = "HorizontalVertical" merge_imports = true +format_code_in_doc_comments = true diff --git a/rust/src/coordinator.rs b/rust/src/coordinator.rs index 05dc57ffa..13e1afd9e 100644 --- a/rust/src/coordinator.rs +++ b/rust/src/coordinator.rs @@ -24,7 +24,6 @@ use crate::{ PetError, SeedDict, SumDict, - SumParticipantEphemeralPublicKey, SumParticipantPublicKey, UpdateParticipantPublicKey, }; @@ -199,6 +198,13 @@ pub trait Coordinators: Sized { self.events_mut().pop_front() } + fn message_open(&self) -> MessageOpen<'_, '_> { + MessageOpen { + recipient_pk: &self.pk, + recipient_sk: &self.sk, + } + } + /// Validate and handle a sum, update or sum2 message. fn handle_message(&mut self, bytes: &[u8]) -> Result<(), PetError> { match self.phase() { @@ -240,8 +246,8 @@ pub trait Coordinators: Sized { /// Validate a sum signature and its implied task. fn validate_sum_task( &self, - sum_signature: &ParticipantTaskSignature, pk: &SumParticipantPublicKey, + sum_signature: &ParticipantTaskSignature, ) -> Result<(), PetError> { if pk.verify_detached(sum_signature, &[self.seed().as_slice(), b"sum"].concat()) && sum_signature.is_eligible(*self.sum()) @@ -255,9 +261,9 @@ pub trait Coordinators: Sized { /// Validate an update signature and its implied task. fn validate_update_task( &self, + pk: &UpdateParticipantPublicKey, sum_signature: &ParticipantTaskSignature, update_signature: &ParticipantTaskSignature, - pk: &UpdateParticipantPublicKey, ) -> Result<(), PetError> { if pk.verify_detached(sum_signature, &[self.seed().as_slice(), b"sum"].concat()) && pk.verify_detached( @@ -621,7 +627,7 @@ impl MaskCoordinators for Coordinator { } } -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] pub struct RoundParameters { /// The coordinator public key for encryption. pub pk: CoordinatorPublicKey, @@ -704,7 +710,7 @@ mod tests { 148, 45, 56, 85, 65, 75, 128, 178, 175, 101, 93, 241, 162, ]); assert_eq!( - coord.validate_sum_task(&sum_signature, &pk).unwrap_err(), + coord.validate_sum_task(&pk, &sum_signature).unwrap_err(), PetError::InvalidMessage, ); } @@ -738,7 +744,7 @@ mod tests { ]); assert_eq!( coord - .validate_update_task(&sum_signature, &update_signature, &pk) + .validate_update_task(&pk, &sum_signature, &update_signature) .unwrap(), (), ); @@ -762,7 +768,7 @@ mod tests { ]); assert_eq!( coord - .validate_update_task(&sum_signature, &update_signature, &pk) + .validate_update_task(&pk, &sum_signature, &update_signature) .unwrap_err(), PetError::InvalidMessage, ); @@ -786,7 +792,7 @@ mod tests { ]); assert_eq!( coord - .validate_update_task(&sum_signature, &update_signature, &pk) + .validate_update_task(&pk, &sum_signature, &update_signature) .unwrap_err(), PetError::InvalidMessage, ); @@ -810,7 +816,7 @@ mod tests { ]); assert_eq!( coord - .validate_update_task(&sum_signature, &update_signature, &pk) + .validate_update_task(&pk, &sum_signature, &update_signature) .unwrap_err(), PetError::InvalidMessage, ); diff --git a/rust/src/crypto/encrypt.rs b/rust/src/crypto/encrypt.rs index 703a09567..4df4ad1e7 100644 --- a/rust/src/crypto/encrypt.rs +++ b/rust/src/crypto/encrypt.rs @@ -3,6 +3,10 @@ use super::ByteObject; use derive_more::{AsMut, AsRef, From}; use sodiumoxide::crypto::{box_, sealedbox}; +/// Number of additional bytes in a ciphertext compared to the +/// corresponding plaintext. +pub const SEALBYTES: usize = sealedbox::SEALBYTES; + /// Generate a new random key pair pub fn generate_encrypt_key_pair() -> (PublicEncryptKey, SecretEncryptKey) { let (pk, sk) = box_::gen_keypair(); @@ -41,9 +45,10 @@ impl ByteObject for PublicEncryptKey { } } -pub const SEALBYTES: usize = sealedbox::SEALBYTES; - impl PublicEncryptKey { + /// Length in bytes of a [`PublicEncryptKey`] + pub const LENGTH: usize = box_::PUBLICKEYBYTES; + /// Encrypt a message `m` with this public key. The resulting /// ciphertext length is [`SEALBYTES`] + `m.len()`. /// @@ -60,6 +65,9 @@ impl PublicEncryptKey { pub struct SecretEncryptKey(box_::SecretKey); impl SecretEncryptKey { + /// Length in bytes of a [`SecretEncryptKey`] + pub const LENGTH: usize = box_::SECRETKEYBYTES; + /// Decrypt the ciphertext `c` using this secret key and the /// associated public key, and return the decrypted message. /// @@ -94,6 +102,9 @@ impl ByteObject for SecretEncryptKey { pub struct EncryptKeySeed(box_::Seed); impl EncryptKeySeed { + /// Length in bytes of a [`EncryptKeySeed`] + pub const LENGTH: usize = box_::SEEDBYTES; + /// Deterministically derive a new key pair from this seed pub fn derive_encrypt_key_pair(&self) -> (PublicEncryptKey, SecretEncryptKey) { let (pk, sk) = box_::keypair_from_seed(self.as_ref()); @@ -107,7 +118,7 @@ impl ByteObject for EncryptKeySeed { } fn zeroed() -> Self { - Self(box_::Seed([0; box_::PUBLICKEYBYTES])) + Self(box_::Seed([0; box_::SEEDBYTES])) } fn as_slice(&self) -> &[u8] { diff --git a/rust/src/crypto/hash.rs b/rust/src/crypto/hash.rs new file mode 100644 index 000000000..040e22c2a --- /dev/null +++ b/rust/src/crypto/hash.rs @@ -0,0 +1,45 @@ +use super::ByteObject; + +use sodiumoxide::crypto::hash::sha256; + +use derive_more::{AsMut, AsRef, From}; + +#[derive( + AsRef, + AsMut, + From, + Serialize, + Deserialize, + Hash, + Eq, + Ord, + PartialEq, + Copy, + Clone, + PartialOrd, + Debug, +)] +pub struct Sha256(sha256::Digest); + +impl ByteObject for Sha256 { + fn zeroed() -> Self { + Self(sha256::Digest([0_u8; sha256::DIGESTBYTES])) + } + + fn as_slice(&self) -> &[u8] { + self.0.as_ref() + } + + fn from_slice(bytes: &[u8]) -> Option { + sha256::Digest::from_slice(bytes).map(Self) + } +} + +impl Sha256 { + /// Length in bytes of a [`Sha256`] + pub const LENGTH: usize = 32; + + pub fn hash(m: &[u8]) -> Self { + Self(sha256::hash(m)) + } +} diff --git a/rust/src/crypto/mod.rs b/rust/src/crypto/mod.rs index eb1247a2c..b12d0d8ab 100644 --- a/rust/src/crypto/mod.rs +++ b/rust/src/crypto/mod.rs @@ -2,6 +2,7 @@ //! primitives. mod encrypt; +mod hash; mod sign; use num::{bigint::BigUint, traits::identities::Zero}; diff --git a/rust/src/crypto/sign.rs b/rust/src/crypto/sign.rs index b7939e408..14b50c366 100644 --- a/rust/src/crypto/sign.rs +++ b/rust/src/crypto/sign.rs @@ -31,6 +31,8 @@ pub fn generate_signing_key_pair() -> (PublicSigningKey, SecretSigningKey) { pub struct PublicSigningKey(sign::PublicKey); impl PublicSigningKey { + /// Length in bytes of a [`PublicSigningKey`] + pub const LENGTH: usize = sign::PUBLICKEYBYTES; /// Verify the signature `s` against the message `m` and the /// signer's public key `&self`. /// @@ -62,6 +64,8 @@ impl ByteObject for PublicSigningKey { pub struct SecretSigningKey(sign::SecretKey); impl SecretSigningKey { + /// Length in bytes of a [`SecretSigningKey`] + pub const LENGTH: usize = sign::SECRETKEYBYTES; /// Sign a message `m` pub fn sign_detached(&self, m: &[u8]) -> Signature { sign::sign_detached(m, self.as_ref()).into() @@ -105,6 +109,11 @@ impl ByteObject for SecretSigningKey { )] pub struct Signature(sign::Signature); +impl Signature { + /// Length in bytes of a [`Signature`] + pub const LENGTH: usize = sign::SIGNATUREBYTES; +} + impl ByteObject for Signature { fn zeroed() -> Self { Self(sign::Signature([0_u8; sign::SIGNATUREBYTES])) @@ -146,6 +155,9 @@ impl Signature { pub struct SigningKeySeed(sign::Seed); impl SigningKeySeed { + /// Length in bytes of a [`SigningKeySeed`] + pub const LENGTH: usize = sign::SEEDBYTES; + /// Deterministically derive a new signing key pair from this seed pub fn derive_signing_key_pair(&self) -> (PublicSigningKey, SecretSigningKey) { let (pk, sk) = sign::keypair_from_seed(&self.0); diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 28da7a903..954fc7591 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -8,7 +8,7 @@ #[macro_use] extern crate tracing; #[macro_use] -extern crate serde; +extern crate tracing; #[macro_use] mod macros; @@ -82,7 +82,7 @@ pub type ParticipantTaskSignature = Signature; /// A dictionary created during the sum phase of the protocol. It maps the public key of every sum /// participant to the ephemeral public key generated by that sum participant. -type SumDict = HashMap; +pub type SumDict = HashMap; /// Local seed dictionaries are sent by update participants. They contain the participant's masking /// seed, encrypted with the ephemeral public key of each sum participant. diff --git a/rust/src/message/buffer.rs b/rust/src/message/buffer.rs new file mode 100644 index 000000000..f4710a9c2 --- /dev/null +++ b/rust/src/message/buffer.rs @@ -0,0 +1,367 @@ +use anyhow::{anyhow, Context}; +use std::ops::{Range, RangeFrom}; + +use crate::{ + message::{utils::range, DecodeError, Flags, LengthValueBuffer}, + CoordinatorPublicKey, + ParticipantPublicKey, +}; + +pub(crate) fn header_length(certificate_length: usize) -> usize { + certificate_length + PARTICIPANT_PK_RANGE.end +} + +// We currently only use 2 bits for the tag, so that byte could also +// be used for something else in the future. +const TAG_RANGE: usize = 0; +// Currently we only have one flag to indicate the presence of a +// certificate. +const FLAGS_RANGE: usize = 1; +// Reserve the remaining 2 bytes for future use. That also allows us +// to have 4 bytes alignment. +const RESERVED: Range = range(2, 2); +const COORDINATOR_PK_RANGE: Range = range(RESERVED.end, CoordinatorPublicKey::LENGTH); +const PARTICIPANT_PK_RANGE: Range = + range(COORDINATOR_PK_RANGE.end, ParticipantPublicKey::LENGTH); + +/// A wrapper around a buffer that contains a message. It provides +/// getters and setters to access the different fields of the message +/// safely. +/// +/// # Examples +/// +/// Reading a sum message: +/// +/// ```rust +/// use xain_fl::message::{Tag, Flags, MessageBuffer}; +/// use std::convert::TryFrom; +/// let mut bytes = vec![ +/// 0x01, // tag = 1 +/// 0x00, // flags = 0 +/// 0x00, 0x00, // reserved bytes, which are ignored +/// ]; +/// bytes.extend(vec![0xaa; 32]); // coordinator public key +/// bytes.extend(vec![0xbb; 32]); // participant public key +/// // Payload: a sum message contains a signature and an ephemeral public key +/// bytes.extend(vec![0x11; 32]); // signature +/// bytes.extend(vec![0x22; 32]); // public key +/// +/// let buffer = MessageBuffer::new(&bytes).unwrap(); +/// assert!(!buffer.has_certificate()); +/// assert_eq!(Tag::try_from(buffer.tag()).unwrap(), Tag::Sum); +/// assert_eq!(buffer.flags(), Flags::empty()); +/// assert!(buffer.certificate().is_none()); +/// assert_eq!(buffer.coordinator_pk(), vec![0xaa; 32].as_slice()); +/// assert_eq!(buffer.participant_pk(), vec![0xbb; 32].as_slice()); +/// assert_eq!(buffer.payload(), [vec![0x11; 32], vec![0x22; 32]].concat().as_slice()); +/// ``` +/// +/// Writing a sum message: +/// +/// ```rust +/// use xain_fl::message::{Tag, Flags, MessageBuffer}; +/// use std::convert::TryFrom; +/// let mut expected = vec![ +/// 0x01, // tag = 1 +/// 0x00, // flags = 0 +/// 0x00, 0x00, // reserved bytes, which are ignored +/// ]; +/// expected.extend(vec![0xaa; 32]); // coordinator public key +/// expected.extend(vec![0xbb; 32]); // participant public key +/// // Payload: a sum message contains a signature and an ephemeral public key +/// expected.extend(vec![0x11; 32]); // signature +/// expected.extend(vec![0x22; 32]); // public key +/// +/// let mut bytes = vec![0; expected.len()]; +/// let mut buffer = MessageBuffer::new_unchecked(&mut bytes); +/// buffer.set_tag(Tag::Sum.into()); +/// buffer.set_flags(Flags::empty()); +/// buffer +/// .coordinator_pk_mut() +/// .copy_from_slice(vec![0xaa; 32].as_slice()); +/// buffer +/// .participant_pk_mut() +/// .copy_from_slice(vec![0xbb; 32].as_slice()); +/// buffer +/// .payload_mut() +/// .copy_from_slice([vec![0x11; 32], vec![0x22; 32]].concat().as_slice()); +/// assert_eq!(expected, bytes); +/// ``` +pub struct MessageBuffer { + inner: T, +} + +impl> MessageBuffer { + /// Perform bound checks for the various message fields on `bytes` + /// and return a new `MessageBuffer`. + pub fn new(bytes: T) -> Result { + let buffer = Self { inner: bytes }; + buffer + .check_buffer_length() + .context("not a valid MessageBuffer")?; + Ok(buffer) + } + + /// Return a `MessageBuffer` without performing any bound + /// check. This means accessing the various fields may panic if + /// the data is invalid. + pub fn new_unchecked(bytes: T) -> Self { + Self { inner: bytes } + } + + /// Perform bound checks to ensure the fields can be accessed + /// without panicking. + pub fn check_buffer_length(&self) -> Result<(), DecodeError> { + let len = self.inner.as_ref().len(); + // First, check the fixed size portion of the + // header. PARTICIPANT_PK_RANGE is the last field + if len < PARTICIPANT_PK_RANGE.end { + return Err(anyhow!( + "invalid buffer length: {} < {}", + len, + PARTICIPANT_PK_RANGE.end + )); + } + + // Check if the header contains a certificate, and if it does, + // check the length of certificate field. + if self.has_certificate() { + let bytes = &self.inner.as_ref()[PARTICIPANT_PK_RANGE.end..]; + let _ = + LengthValueBuffer::new(bytes).context("certificate field has an invalid lenth")?; + } + + Ok(()) + } + + /// Return whether this header contains a certificate + pub fn has_certificate(&self) -> bool { + self.flags().contains(Flags::CERTIFICATE) + } + + fn payload_range(&self) -> RangeFrom { + let certificate_length = self + .has_certificate() + .then(|| { + let bytes = &self.inner.as_ref()[PARTICIPANT_PK_RANGE.end..]; + LengthValueBuffer::new(bytes).unwrap().length() as usize + }) + .unwrap_or(0); + let payload_start = PARTICIPANT_PK_RANGE.end + certificate_length; + payload_start.. + } + + /// Get the tag field + pub fn tag(&self) -> u8 { + self.inner.as_ref()[TAG_RANGE] + } + + /// Get the flags field + pub fn flags(&self) -> Flags { + Flags::from_bits_truncate(self.inner.as_ref()[FLAGS_RANGE]) + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> MessageBuffer<&'a T> { + /// Get a slice to the certificate. If the header doesn't contain + /// any certificate, `None` is returned. + pub fn certificate(&self) -> Option> { + self.has_certificate().then(|| { + let bytes = &self.inner.as_ref()[PARTICIPANT_PK_RANGE.end..]; + LengthValueBuffer::new_unchecked(bytes) + }) + } + + /// Get the coordinator public key field + pub fn coordinator_pk(&self) -> &'a [u8] { + &self.inner.as_ref()[COORDINATOR_PK_RANGE] + } + + /// Get the participant public key field + pub fn participant_pk(&self) -> &'a [u8] { + &self.inner.as_ref()[PARTICIPANT_PK_RANGE] + } + + /// Get the rest of the message + pub fn payload(&self) -> &'a [u8] { + &self.inner.as_ref()[self.payload_range()] + } +} + +impl + AsRef<[u8]>> MessageBuffer { + /// Set the tag field + pub fn set_tag(&mut self, value: u8) { + self.inner.as_mut()[TAG_RANGE] = value; + } + + /// Set the flags field + pub fn set_flags(&mut self, value: Flags) { + self.inner.as_mut()[FLAGS_RANGE] = value.bits(); + } + + /// Get a mutable reference to the certificate field + pub fn certificate_mut(&mut self) -> Option> { + if self.has_certificate() { + let bytes = &mut self.inner.as_mut()[PARTICIPANT_PK_RANGE.end..]; + Some(LengthValueBuffer::new_unchecked(bytes)) + } else { + None + } + } + /// Get a mutable reference to the coordinator public key field + pub fn coordinator_pk_mut(&mut self) -> &mut [u8] { + &mut self.inner.as_mut()[COORDINATOR_PK_RANGE] + } + + /// Get a mutable reference to the participant public key field + pub fn participant_pk_mut(&mut self) -> &mut [u8] { + &mut self.inner.as_mut()[PARTICIPANT_PK_RANGE] + } + + /// Get a mutable reference to the rest of the message + pub fn payload_mut(&mut self) -> &mut [u8] { + let range = self.payload_range(); + &mut self.inner.as_mut()[range] + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use crate::{ + certificate::Certificate, + crypto::ByteObject, + message::{sum, HeaderOwned, MessageOwned, Tag}, + }; + use std::convert::TryFrom; + + fn coordinator_pk() -> (Vec, CoordinatorPublicKey) { + let bytes = vec![0xaa; 32]; + let pk = CoordinatorPublicKey::from_slice(bytes.as_slice()).unwrap(); + (bytes, pk) + } + + fn participant_pk() -> (Vec, ParticipantPublicKey) { + let bytes = vec![0xbb; 32]; + let pk = ParticipantPublicKey::from_slice(&bytes).unwrap(); + (bytes, pk) + } + + fn certificate() -> (Vec, Certificate) { + let bytes = vec![0x01; 32]; + let cert = Certificate::try_from(bytes.as_slice()).unwrap(); + (bytes, cert) + } + + pub(crate) fn header_bytes(tag: Tag, with_certificate: bool) -> Vec { + let mut buf = vec![ + tag.into(), + // flags + if with_certificate { 1 } else { 0 }, + // reserved bytes, which can be anything + 0xff, + 0xff, + ]; + buf.extend(coordinator_pk().0); + buf.extend(participant_pk().0); + if with_certificate { + // certificate length + buf.extend(vec![0x00, 0x00, 0x00, 32 + 4]); + buf.extend(certificate().0); + } + buf + } + + pub(crate) fn header(tag: Tag, with_certificate: bool) -> HeaderOwned { + HeaderOwned { + tag, + coordinator_pk: coordinator_pk().1, + participant_pk: participant_pk().1, + certificate: if with_certificate { + Some(certificate().1) + } else { + None + }, + } + } + + fn sum(with_certificate: bool) -> (Vec, MessageOwned) { + let mut bytes = header_bytes(Tag::Sum, with_certificate); + bytes.extend(sum::tests::sum_bytes()); + + let header = header(Tag::Sum, with_certificate); + let payload = sum::tests::sum().into(); + let message = MessageOwned { header, payload }; + (bytes, message) + } + + #[test] + fn buffer_read_no_cert() { + let (bytes, _) = sum(false); + let buffer = MessageBuffer::new(&bytes).unwrap(); + assert!(!buffer.has_certificate()); + assert_eq!(Tag::try_from(buffer.tag()).unwrap(), Tag::Sum); + assert_eq!(buffer.flags(), Flags::empty()); + assert!(buffer.certificate().is_none()); + assert_eq!(buffer.coordinator_pk(), coordinator_pk().0.as_slice()); + assert_eq!(buffer.participant_pk(), participant_pk().0.as_slice()); + } + + #[test] + fn buffer_read_with_cert() { + let (bytes, _) = sum(true); + let buffer = MessageBuffer::new(&bytes).unwrap(); + assert!(buffer.has_certificate()); + assert_eq!(Tag::try_from(buffer.tag()).unwrap(), Tag::Sum); + assert_eq!(buffer.flags(), Flags::CERTIFICATE); + assert_eq!(buffer.certificate().unwrap().value(), &certificate().0[..]); + assert_eq!(buffer.coordinator_pk(), coordinator_pk().0.as_slice()); + assert_eq!(buffer.participant_pk(), participant_pk().0.as_slice()); + } + + #[test] + fn buffer_write_no_cert() { + let expected = sum(false).0; + let mut bytes = vec![0xff; expected.len()]; + let mut buffer = MessageBuffer::new_unchecked(&mut bytes); + + buffer.set_tag(Tag::Sum.into()); + buffer.set_flags(Flags::empty()); + buffer + .coordinator_pk_mut() + .copy_from_slice(coordinator_pk().0.as_slice()); + buffer + .participant_pk_mut() + .copy_from_slice(participant_pk().0.as_slice()); + buffer + .payload_mut() + .copy_from_slice(sum::tests::sum_bytes().as_slice()); + assert_eq!(bytes, expected); + } + + #[test] + fn buffer_write_with_cert() { + let expected = sum(true).0; + let mut bytes = vec![0xff; expected.len()]; + let mut buffer = MessageBuffer::new_unchecked(&mut bytes); + + buffer.set_tag(Tag::Sum.into()); + buffer.set_flags(Flags::CERTIFICATE); + buffer + .coordinator_pk_mut() + .copy_from_slice(coordinator_pk().0.as_slice()); + buffer + .participant_pk_mut() + .copy_from_slice(participant_pk().0.as_slice()); + buffer.certificate_mut().unwrap().set_length(32 + 4); + buffer + .certificate_mut() + .unwrap() + .value_mut() + .copy_from_slice(certificate().0.as_slice()); + buffer + .payload_mut() + .copy_from_slice(sum::tests::sum_bytes().as_slice()); + assert_eq!(bytes, expected); + } +} diff --git a/rust/src/message/header.rs b/rust/src/message/header.rs new file mode 100644 index 000000000..f3391eeda --- /dev/null +++ b/rust/src/message/header.rs @@ -0,0 +1,116 @@ +use anyhow::{anyhow, Context}; +use std::{borrow::Borrow, convert::TryFrom}; + +use crate::{ + certificate::Certificate, + message::{header_length, DecodeError, FromBytes, MessageBuffer, ToBytes}, + CoordinatorPublicKey, + ParticipantPublicKey, +}; + +#[derive(Copy, Debug, Clone, Eq, PartialEq)] +/// Tag that indicates the type of message +pub enum Tag { + /// Tag for sum messages + Sum, + /// Tag for update messages + Update, + /// Tag for sum2 messages + Sum2, +} + +impl TryFrom for Tag { + type Error = DecodeError; + + fn try_from(value: u8) -> Result { + Ok(match value { + 1 => Tag::Sum, + 2 => Tag::Update, + 3 => Tag::Sum2, + _ => return Err(anyhow!("invalid tag {}", value)), + }) + } +} + +impl Into for Tag { + fn into(self) -> u8 { + match self { + Tag::Sum => 1, + Tag::Update => 2, + Tag::Sum2 => 3, + } + } +} + +const CERTIFICATE_FLAG: u8 = 0; + +bitflags::bitflags! { + /// Bitmask that defines flags for a message + pub struct Flags: u8 { + /// Indicates the presence of a client certificate in the + /// message + const CERTIFICATE = 1 << CERTIFICATE_FLAG; + } +} + +/// A header common to all the messages +pub struct Header { + /// Type of message + pub tag: Tag, + /// Coordinator public key + pub coordinator_pk: CoordinatorPublicKey, + /// Participant public key + pub participant_pk: ParticipantPublicKey, + /// A certificate that identifies the author of the message + pub certificate: Option, +} + +impl ToBytes for Header +where + C: Borrow, +{ + fn buffer_length(&self) -> usize { + let cert_length = self + .certificate + .as_ref() + .map(|cert| cert.borrow().buffer_length()) + .unwrap_or(0); + header_length(cert_length) + } + + fn to_bytes>(&self, buffer: &mut T) { + let mut writer = MessageBuffer::new(buffer.as_mut()).unwrap(); + writer.set_tag(self.tag.into()); + if self.certificate.is_some() { + writer.set_flags(Flags::CERTIFICATE); + } else { + writer.set_flags(Flags::empty()); + } + self.coordinator_pk + .to_bytes(&mut writer.coordinator_pk_mut()); + self.participant_pk + .to_bytes(&mut writer.participant_pk_mut()); + } +} + +/// Owned version of a [`Header`] +pub type HeaderOwned = Header; + +impl FromBytes for HeaderOwned { + fn from_bytes>(buffer: &T) -> Result { + let reader = MessageBuffer::new(buffer.as_ref())?; + let certificate = if let Some(bytes) = reader.certificate() { + Some(Certificate::from_bytes(&bytes.value())?) + } else { + None + }; + Ok(Self { + tag: Tag::try_from(reader.tag())?, + coordinator_pk: CoordinatorPublicKey::from_bytes(&reader.coordinator_pk()) + .context("invalid coordinator public key")?, + participant_pk: ParticipantPublicKey::from_bytes(&reader.participant_pk()) + .context("invalid participant public key")?, + certificate, + }) + } +} diff --git a/rust/src/message/message.rs b/rust/src/message/message.rs new file mode 100644 index 000000000..0db0ce6f5 --- /dev/null +++ b/rust/src/message/message.rs @@ -0,0 +1,182 @@ +use anyhow::{anyhow, Context}; +use std::borrow::Borrow; + +use crate::{ + certificate::Certificate, + crypto::{ByteObject, PublicEncryptKey, SecretEncryptKey, SecretSigningKey, Signature}, + mask::{Mask, MaskedModel}, + message::{ + DecodeError, + FromBytes, + Header, + HeaderOwned, + Payload, + PayloadOwned, + Sum2Owned, + SumOwned, + Tag, + ToBytes, + UpdateOwned, + }, + LocalSeedDict, +}; + +/// A message +pub struct Message { + /// Message header + pub header: Header, + /// Message payload + pub payload: Payload, +} + +pub type MessageOwned = Message; + +macro_rules! impl_new { + ($name:ident, $payload:ty, $tag:expr) => { + paste::item! { + pub fn []( + coordinator_pk: $crate::CoordinatorPublicKey, + participant_pk: $crate::ParticipantPublicKey, + payload: $payload) -> Self + { + Self { + header: Header { + coordinator_pk, + participant_pk, + tag: $tag, + certificate: None, + }, + payload: $crate::message::Payload::from(payload), + } + } + } + }; +} + +impl Message +where + C: Borrow, + D: Borrow, + M: Borrow, + N: Borrow, +{ + impl_new!(sum, crate::message::Sum, Tag::Sum); + impl_new!(update, crate::message::Update, Tag::Update); + impl_new!(sum2, crate::message::Sum2, Tag::Sum2); +} + +impl ToBytes for Message +where + C: Borrow, + D: Borrow, + M: Borrow, + N: Borrow, +{ + fn buffer_length(&self) -> usize { + self.header.buffer_length() + self.payload.buffer_length() + } + + fn to_bytes>(&self, buffer: &mut T) { + self.header.to_bytes(buffer); + let mut payload_slice = &mut buffer.as_mut()[self.header.buffer_length()..]; + self.payload.to_bytes(&mut payload_slice); + } +} + +impl FromBytes for MessageOwned { + fn from_bytes>(buffer: &T) -> Result { + let header = HeaderOwned::from_bytes(&buffer)?; + let payload_slice = &buffer.as_ref()[header.buffer_length()..]; + let payload = match header.tag { + Tag::Sum => PayloadOwned::Sum( + SumOwned::from_bytes(&payload_slice).context("invalid sum payload")?, + ), + Tag::Update => PayloadOwned::Update( + UpdateOwned::from_bytes(&payload_slice).context("invalid update payload")?, + ), + Tag::Sum2 => PayloadOwned::Sum2( + Sum2Owned::from_bytes(&payload_slice).context("invalid sum2 payload")?, + ), + }; + Ok(Self { header, payload }) + } +} + +/// A seal to sign and encrypt messages +pub struct MessageSeal<'a, 'b> { + /// Public key of the recipient, used to encrypt messages + pub recipient_pk: &'a PublicEncryptKey, + /// Secret key of the sender, used to sign messages + pub sender_sk: &'b SecretSigningKey, +} + +impl<'a, 'b> MessageSeal<'a, 'b> { + /// Sign and encrypt the given message + pub fn seal(&self, message: &Message) -> Vec + where + C: Borrow, + D: Borrow, + M: Borrow, + N: Borrow, + { + let signed_message = self.sign(&message); + self.recipient_pk.encrypt(&signed_message[..]) + } + + /// Sign the given message + fn sign(&self, message: &Message) -> Vec + where + C: Borrow, + D: Borrow, + M: Borrow, + N: Borrow, + { + let signed_payload_length = message.buffer_length() + Signature::LENGTH; + + let mut buffer = Vec::with_capacity(signed_payload_length); + message.to_bytes(&mut &mut buffer[Signature::LENGTH..]); + + let signature = self.sender_sk.sign_detached(&buffer[Signature::LENGTH..]); + signature.to_bytes(&mut &mut buffer[..Signature::LENGTH]); + + buffer + } +} + +/// A message opener that decrypts a message and verifies its signature +pub struct MessageOpen<'a, 'b> { + /// Secret key for decrypting the message + pub recipient_sk: &'b SecretEncryptKey, + /// Public key for decrypting the message + pub recipient_pk: &'a PublicEncryptKey, +} + +impl<'a, 'b> MessageOpen<'a, 'b> { + pub fn open>(&self, buffer: &T) -> Result { + // Step 1: decrypt the message + let bytes = self + .recipient_sk + .decrypt(buffer.as_ref(), self.recipient_pk) + .map_err(|_| anyhow!("invalid message: failed to decrypt message"))?; + + if bytes.len() < Signature::LENGTH { + return Err(anyhow!("invalid message: invalid length")); + } + + // UNWRAP_SAFE: the slice is exactly the size from_slice + // expects. + let signature = Signature::from_slice(&bytes[..Signature::LENGTH]).unwrap(); + + let message_bytes = &bytes[Signature::LENGTH..]; + let message = + MessageOwned::from_bytes(&message_bytes).context("invalid message: parsing failed")?; + if !message + .header + .participant_pk + .verify_detached(&signature, message_bytes) + { + return Err(anyhow!("invalid message: invalid signature")); + } + Ok(message) + } +} diff --git a/rust/src/message/mod.rs b/rust/src/message/mod.rs index 0056729a6..25b48c041 100644 --- a/rust/src/message/mod.rs +++ b/rust/src/message/mod.rs @@ -1,13 +1,10 @@ -pub mod sum; -pub mod sum2; -pub mod update; +pub(crate) mod utils; -use std::{ - mem, - ops::{Range, RangeFrom, RangeTo}, -}; +mod traits; +pub use self::traits::{FromBytes, LengthValueBuffer, ToBytes}; -use sodiumoxide::crypto::{box_, sign}; +mod buffer; +pub use self::buffer::*; #[repr(u8)] /// Message tags. @@ -19,119 +16,12 @@ enum Tag { Sum2, } -/// Get the number of bytes of a signature field. -const SIGNATURE_BYTES: usize = sign::SIGNATUREBYTES; +pub(crate) mod payload; +pub use self::payload::*; -/// Get the number of bytes of a message tag field. -const TAG_BYTES: usize = 1; +mod message; +pub use self::message::*; -/// Get the number of bytes of a public key field. -const PK_BYTES: usize = box_::PUBLICKEYBYTES; - -/// Get the number of bytes of a length field. -const LEN_BYTES: usize = mem::size_of::(); - -trait MessageBuffer { - /// Get the range of the signature field. - const SIGNATURE_RANGE: RangeTo = ..SIGNATURE_BYTES; - - /// Get the range of the message field. - const MESSAGE_RANGE: RangeFrom = Self::SIGNATURE_RANGE.end..; - - /// Get the range of the tag field. - const TAG_RANGE: Range = - Self::SIGNATURE_RANGE.end..Self::SIGNATURE_RANGE.end + TAG_BYTES; - - /// Get the range of the coordinator public key field. - const COORD_PK_RANGE: Range = Self::TAG_RANGE.end..Self::TAG_RANGE.end + PK_BYTES; - - /// Get the range of the participant public key field. - const PART_PK_RANGE: Range = - Self::COORD_PK_RANGE.end..Self::COORD_PK_RANGE.end + PK_BYTES; - - /// Get the range of the sum signature field. - const SUM_SIGNATURE_RANGE: Range = - Self::PART_PK_RANGE.end..Self::PART_PK_RANGE.end + SIGNATURE_BYTES; - - /// Get a reference to the message buffer. - fn bytes(&'_ self) -> &'_ [u8]; - - /// Get a mutable reference to the message buffer. - fn bytes_mut(&mut self) -> &mut [u8]; - - /// Get the length of the message buffer. - fn len(&self) -> usize { - self.bytes().len() - } - - /// Get a reference to the signature field. - fn signature(&'_ self) -> &'_ [u8] { - &self.bytes()[Self::SIGNATURE_RANGE] - } - - /// Get a mutable reference to the signature field. - fn signature_mut(&mut self) -> &mut [u8] { - &mut self.bytes_mut()[Self::SIGNATURE_RANGE] - } - - /// Get a reference to the message field. - fn message(&'_ self) -> &'_ [u8] { - &self.bytes()[Self::MESSAGE_RANGE] - } - - /// Get a reference to the tag field. - fn tag(&'_ self) -> &'_ [u8] { - &self.bytes()[Self::TAG_RANGE] - } - - /// Get a mutable reference to the tag field. - fn tag_mut(&mut self) -> &mut [u8] { - &mut self.bytes_mut()[Self::TAG_RANGE] - } - - /// Get a reference to the coordinator public key field. - fn coord_pk(&'_ self) -> &'_ [u8] { - &self.bytes()[Self::COORD_PK_RANGE] - } - - /// Get a mutable reference to the coordinator public key field. - fn coord_pk_mut(&mut self) -> &mut [u8] { - &mut self.bytes_mut()[Self::COORD_PK_RANGE] - } - - /// Get a reference to the participant public key field. - fn part_pk(&'_ self) -> &'_ [u8] { - &self.bytes()[Self::PART_PK_RANGE] - } - - /// Get a mutable reference to the participant public key field. - fn part_pk_mut(&mut self) -> &mut [u8] { - &mut self.bytes_mut()[Self::PART_PK_RANGE] - } - - /// Get a reference to the sum signature field. - fn sum_signature(&'_ self) -> &'_ [u8] { - &self.bytes()[Self::SUM_SIGNATURE_RANGE] - } - - /// Get a mutable reference to the sum signature field. - fn sum_signature_mut(&mut self) -> &mut [u8] { - &mut self.bytes_mut()[Self::SUM_SIGNATURE_RANGE] - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_constants() { - // just to make sure that the constants were not changed accidentally, because a lot of - // assumptions are based on those - assert_eq!(SIGNATURE_BYTES, sign::SIGNATUREBYTES); - assert_eq!(TAG_BYTES, 1); - assert_eq!(PK_BYTES, box_::PUBLICKEYBYTES); - assert_eq!(PK_BYTES, sign::PUBLICKEYBYTES); - assert_eq!(LEN_BYTES, mem::size_of::()); - } -} +/// Error that signals a failure when trying to decrypt and parse a +/// message +pub type DecodeError = anyhow::Error; diff --git a/rust/src/message/payload/mod.rs b/rust/src/message/payload/mod.rs new file mode 100644 index 000000000..55b6712b9 --- /dev/null +++ b/rust/src/message/payload/mod.rs @@ -0,0 +1,52 @@ +use std::borrow::Borrow; + +pub(crate) mod sum; +pub use self::sum::*; +pub(crate) mod sum2; +pub use self::sum2::*; +pub(crate) mod update; +pub use self::update::*; + +use derive_more::From; + +use crate::{ + mask::{Mask, MaskedModel}, + message::traits::ToBytes, + LocalSeedDict, +}; + +/// Payload of a [`Message`] +#[derive(From, Eq, PartialEq, Clone, Debug)] +pub enum Payload { + /// Payload of a sum message + Sum(Sum), + /// Payload of an update message + Update(Update), + /// Payload of a sum2 message + Sum2(Sum2), +} + +pub type PayloadOwned = Payload; + +impl ToBytes for Payload +where + D: Borrow, + M: Borrow, + N: Borrow, +{ + fn buffer_length(&self) -> usize { + match self { + Payload::Sum(m) => m.buffer_length(), + Payload::Sum2(m) => m.buffer_length(), + Payload::Update(m) => m.buffer_length(), + } + } + + fn to_bytes>(&self, buffer: &mut T) { + match self { + Payload::Sum(m) => m.to_bytes(buffer), + Payload::Sum2(m) => m.to_bytes(buffer), + Payload::Update(m) => m.to_bytes(buffer), + } + } +} diff --git a/rust/src/message/payload/sum.rs b/rust/src/message/payload/sum.rs new file mode 100644 index 000000000..632d6a40c --- /dev/null +++ b/rust/src/message/payload/sum.rs @@ -0,0 +1,287 @@ +use anyhow::{anyhow, Context}; +use std::ops::Range; + +use crate::{ + message::{utils::range, DecodeError, FromBytes, ToBytes}, + ParticipantTaskSignature, + SumParticipantEphemeralPublicKey, +}; + +const SUM_SIGNATURE_RANGE: Range = range(0, ParticipantTaskSignature::LENGTH); +const EPHM_PK_RANGE: Range = range( + SUM_SIGNATURE_RANGE.end, + SumParticipantEphemeralPublicKey::LENGTH, +); + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +/// A wrapper around a buffer that contains a sum message. It provides +/// getters and setters to access the different fields of the message +/// safely. +/// +/// # Examples +/// +/// Decoding a sum message: +/// +/// ```rust +/// # use xain_fl::message::SumBuffer; +/// let sum_signature = vec![0x11; 64]; +/// let ephm_pk = vec![0x22; 32]; +/// let bytes = [sum_signature.as_slice(), ephm_pk.as_slice()].concat(); +/// let buffer = SumBuffer::new(&bytes).unwrap(); +/// assert_eq!(buffer.sum_signature(), sum_signature.as_slice()); +/// assert_eq!(buffer.ephm_pk(), ephm_pk.as_slice()); +/// ``` +/// +/// Encoding a sum message: +/// +/// ```rust +/// # use xain_fl::message::SumBuffer; +/// let sum_signature = vec![0x11; 64]; +/// let ephm_pk = vec![0x22; 32]; +/// let mut storage = vec![0xff; 96]; +/// let mut buffer = SumBuffer::new_unchecked(&mut storage); +/// buffer +/// .sum_signature_mut() +/// .copy_from_slice(&sum_signature[..]); +/// buffer.ephm_pk_mut().copy_from_slice(&ephm_pk[..]); +/// assert_eq!(&storage[..64], sum_signature.as_slice()); +/// assert_eq!(&storage[64..], ephm_pk.as_slice()); +/// ``` +pub struct SumBuffer { + inner: T, +} + +impl> SumBuffer { + /// Perform bound checks for the various message fields on `bytes` + /// and return a new `SumBuffer`. + pub fn new(bytes: T) -> Result { + let buffer = Self { inner: bytes }; + buffer + .check_buffer_length() + .context("not a valid SumBuffer")?; + Ok(buffer) + } + + /// Return a `SumBuffer` without performing any bound + /// check. This means accessing the various fields may panic if + /// the data is invalid. + pub fn new_unchecked(bytes: T) -> Self { + Self { inner: bytes } + } + + /// Perform bound checks to ensure the fields can be accessed + /// without panicking. + pub fn check_buffer_length(&self) -> Result<(), DecodeError> { + let len = self.inner.as_ref().len(); + if len < EPHM_PK_RANGE.end { + Err(anyhow!( + "invalid buffer length: {} < {}", + len, + EPHM_PK_RANGE.end + )) + } else { + Ok(()) + } + } +} + +impl> SumBuffer { + /// Get a mutable reference to the sum participant ephemeral + /// public key field + /// + /// # Panic + /// + /// This may panic if the underlying buffer does not represent a + /// valid update. If `self.check_buffer_length()` returned + /// `Ok(())` this method is guaranteed not to panic. + pub fn ephm_pk_mut(&mut self) -> &mut [u8] { + &mut self.inner.as_mut()[EPHM_PK_RANGE] + } + + /// Get a mutable reference to the sum signature field + /// + /// # Panic + /// + /// This may panic if the underlying buffer does not represent a + /// valid update. If `self.check_buffer_length()` returned + /// `Ok(())` this method is guaranteed not to panic. + pub fn sum_signature_mut(&mut self) -> &mut [u8] { + &mut self.inner.as_mut()[SUM_SIGNATURE_RANGE] + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> SumBuffer<&'a T> { + /// Get a reference to the sum participant ephemeral public key + /// field + /// + /// # Panic + /// + /// This may panic if the underlying buffer does not represent a + /// valid update. If `self.check_buffer_length()` returned + /// `Ok(())` this method is guaranteed not to panic. + pub fn ephm_pk(&self) -> &'a [u8] { + &self.inner.as_ref()[EPHM_PK_RANGE] + } + + /// Get a reference to the sum signature field + /// + /// # Panic + /// + /// This may panic if the underlying buffer does not represent a + /// valid update. If `self.check_buffer_length()` returned + /// `Ok(())` this method is guaranteed not to panic. + pub fn sum_signature(&self) -> &'a [u8] { + &self.inner.as_ref()[SUM_SIGNATURE_RANGE] + } +} + +/// High level representation of a sum message. These messages are +/// sent by sum participants during the sum phase. +/// +/// # Examples +/// +/// ## Decoding a message +/// +/// ```rust +/// # use xain_fl::{crypto::ByteObject, message::{FromBytes, SumOwned}, ParticipantTaskSignature, SumParticipantEphemeralPublicKey}; +/// let signature = vec![0x11; 64]; +/// let ephm_pk = vec![0x22; 32]; +/// let bytes = [signature.as_slice(), ephm_pk.as_slice()].concat(); +/// let parsed = SumOwned::from_bytes(&bytes).unwrap(); +/// let expected = SumOwned { +/// sum_signature: ParticipantTaskSignature::from_slice(&signature[..]).unwrap(), +/// ephm_pk: SumParticipantEphemeralPublicKey::from_slice(&ephm_pk[..]).unwrap(), +/// }; +/// assert_eq!(parsed, expected); +/// ``` +/// +/// ## Encoding a message +/// +/// ```rust +/// # use xain_fl::{crypto::ByteObject, message::{ToBytes, Sum}, ParticipantTaskSignature, SumParticipantEphemeralPublicKey}; +/// let sum_signature = ParticipantTaskSignature::from_slice(vec![0x11; 64].as_slice()).unwrap(); +/// let ephm_pk = SumParticipantEphemeralPublicKey::from_slice(vec![0x22; 32].as_slice()).unwrap(); +/// let msg = Sum { +/// sum_signature, +/// ephm_pk, +/// }; +/// // we need a 96 bytes long buffer to serialize that message +/// assert_eq!(msg.buffer_length(), 96); +/// // create a buffer with enough space and encode the message +/// let mut buf = vec![0xff; 96]; +/// msg.to_bytes(&mut buf); +/// +/// assert_eq!(buf, [vec![0x11; 64].as_slice(), vec![0x22; 32].as_slice()].concat()); +/// ``` +#[derive(Debug, Eq, PartialEq, Clone)] +pub struct Sum { + /// Signature of the round seed and the word "sum", used to + /// determine whether a participant is selected for the sum task + pub sum_signature: ParticipantTaskSignature, + /// A public key generated by a sum participant for the current + /// round. + pub ephm_pk: SumParticipantEphemeralPublicKey, +} + +pub type SumOwned = Sum; + +impl ToBytes for Sum { + fn buffer_length(&self) -> usize { + EPHM_PK_RANGE.end + } + + fn to_bytes>(&self, buffer: &mut T) { + let mut writer = SumBuffer::new(buffer.as_mut()).unwrap(); + self.sum_signature.to_bytes(&mut writer.sum_signature_mut()); + self.ephm_pk.to_bytes(&mut writer.ephm_pk_mut()); + } +} + +impl FromBytes for SumOwned { + /// Deserialize a sum message from a buffer. + fn from_bytes>(buffer: &T) -> Result { + let reader = SumBuffer::new(buffer.as_ref())?; + + let sum_signature = ParticipantTaskSignature::from_bytes(&reader.sum_signature()) + .context("invalid sum signature")?; + + let ephm_pk = SumParticipantEphemeralPublicKey::from_bytes(&reader.ephm_pk()) + .context("invalid ephemeral public key")?; + + Ok(Self { + sum_signature, + ephm_pk, + }) + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use crate::crypto::ByteObject; + + fn sum_signature_bytes() -> Vec { + vec![0x11; ParticipantTaskSignature::LENGTH] + } + + fn ephm_pk_bytes() -> Vec { + vec![0x22; SumParticipantEphemeralPublicKey::LENGTH] + } + + pub(crate) fn sum_bytes() -> Vec { + [sum_signature_bytes().as_slice(), ephm_pk_bytes().as_slice()].concat() + } + + pub(crate) fn sum() -> SumOwned { + let sum_signature = + ParticipantTaskSignature::from_slice(&sum_signature_bytes()[..]).unwrap(); + let ephm_pk = SumParticipantEphemeralPublicKey::from_slice(&ephm_pk_bytes()).unwrap(); + SumOwned { + sum_signature, + ephm_pk, + } + } + + #[test] + fn buffer_read() { + let bytes = sum_bytes(); + let buffer = SumBuffer::new(&bytes).unwrap(); + assert_eq!(buffer.sum_signature(), &sum_signature_bytes()[..]); + assert_eq!(buffer.ephm_pk(), &ephm_pk_bytes()[..]); + } + + #[test] + fn buffer_read_invalid() { + let bytes = sum_bytes(); + assert!(SumBuffer::new(&bytes[1..]).is_err()); + } + + #[test] + fn buffer_write() { + let mut buffer = vec![0xff; EPHM_PK_RANGE.end]; + let mut writer = SumBuffer::new_unchecked(&mut buffer); + writer + .sum_signature_mut() + .copy_from_slice(sum_signature_bytes().as_slice()); + writer + .ephm_pk_mut() + .copy_from_slice(ephm_pk_bytes().as_slice()); + } + + #[test] + fn encode() { + let message = sum(); + assert_eq!(message.buffer_length(), sum_bytes().len()); + + let mut buf = vec![0xff; message.buffer_length()]; + message.to_bytes(&mut buf); + assert_eq!(buf, sum_bytes()); + } + + #[test] + fn decode() { + let parsed = SumOwned::from_bytes(&sum_bytes()).unwrap(); + let expected = sum(); + assert_eq!(parsed, expected); + } +} diff --git a/rust/src/message/payload/sum2.rs b/rust/src/message/payload/sum2.rs new file mode 100644 index 000000000..ee0c4dac2 --- /dev/null +++ b/rust/src/message/payload/sum2.rs @@ -0,0 +1,304 @@ +use std::{borrow::Borrow, ops::Range}; + +use crate::{ + mask::Mask, + message::{utils::range, DecodeError, FromBytes, LengthValueBuffer, ToBytes}, + ParticipantTaskSignature, +}; +use anyhow::{anyhow, Context}; + +const SUM_SIGNATURE_RANGE: Range = range(0, ParticipantTaskSignature::LENGTH); + +/// A wrapper around a buffer that contains a sum2 message. It provides +/// getters and setters to access the different fields of the message +/// safely. +/// +/// # Examples +/// +/// Decoding a sum2 message: +/// +/// ```rust +/// # use xain_fl::message::Sum2Buffer; +/// let signature = vec![0x11; 64]; +/// let mask = vec![ +/// 0x00, 0x00, 0x00, 0x08, // Length 8 +/// 0x00, 0x01, 0x02, 0x03, // Value: 0, 1, 2, 3 +/// ]; +/// let bytes = [signature.as_slice(), mask.as_slice()].concat(); +/// +/// let buffer = Sum2Buffer::new(&bytes).unwrap(); +/// assert_eq!(buffer.sum_signature(), &bytes[..64]); +/// assert_eq!(buffer.mask(), &bytes[64..]); +/// ``` +/// +/// Encoding a sum2 message: +/// +/// ```rust +/// # use xain_fl::message::Sum2Buffer; +/// let signature = vec![0x11; 64]; +/// let mask = vec![ +/// 0x00, 0x00, 0x00, 0x08, // Length 8 +/// 0x00, 0x01, 0x02, 0x03, // Value: 0, 1, 2, 3 +/// ]; +/// let mut bytes = vec![0xff; 72]; +/// { +/// let mut buffer = Sum2Buffer::new_unchecked(&mut bytes); +/// buffer.sum_signature_mut().copy_from_slice(&signature[..]); +/// buffer.mask_mut().copy_from_slice(&mask[..]); +/// } +/// assert_eq!(&bytes[..64], &signature[..]); +/// assert_eq!(&bytes[64..], &mask[..]); +/// ``` +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct Sum2Buffer { + inner: T, +} + +impl> Sum2Buffer { + /// Perform bound checks for the various message fields on `bytes` + /// and return a new `Sum2Buffer`. + pub fn new(bytes: T) -> Result { + let buffer = Self { inner: bytes }; + buffer + .check_buffer_length() + .context("not a valid Sum2Buffer")?; + Ok(buffer) + } + + /// Return a `Sum2Buffer` with the given `bytes` without + /// performing bound checks. This means that accessing the message + /// fields may panic. + pub fn new_unchecked(bytes: T) -> Self { + Self { inner: bytes } + } + + /// Perform bound checks for the various message fields on this + /// buffer. + pub fn check_buffer_length(&self) -> Result<(), DecodeError> { + let len = self.inner.as_ref().len(); + if len < SUM_SIGNATURE_RANGE.end { + return Err(anyhow!( + "invalid buffer length: {} < {}", + len, + SUM_SIGNATURE_RANGE.end + )); + } + + LengthValueBuffer::new(&self.inner.as_ref()[SUM_SIGNATURE_RANGE.end..])?; + Ok(()) + } +} + +impl> Sum2Buffer { + /// Get a mutable reference to the sum signature field + /// + /// # Panic + /// + /// This may panic if the underlying buffer does not represent a + /// valid sum2 message. If `self.check_buffer_length()` returned + /// `Ok(())` this method is guaranteed not to panic. + pub fn sum_signature_mut(&mut self) -> &mut [u8] { + &mut self.inner.as_mut()[SUM_SIGNATURE_RANGE] + } + + /// Get a mutable reference to the mask field + /// + /// # Panic + /// + /// This may panic if the underlying buffer does not represent a + /// valid sum2 message. If `self.check_buffer_length()` returned + /// `Ok(())` this method is guaranteed not to panic. + pub fn mask_mut(&mut self) -> &mut [u8] { + &mut self.inner.as_mut()[SUM_SIGNATURE_RANGE.end..] + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Sum2Buffer<&'a T> { + /// Get a reference to the sum signature field + /// + /// # Panic + /// + /// This may panic if the underlying buffer does not represent a + /// valid sum2 message. If `self.check_buffer_length()` returned + /// `Ok(())` this method is guaranteed not to panic. + pub fn sum_signature(&self) -> &'a [u8] { + &self.inner.as_ref()[SUM_SIGNATURE_RANGE] + } + + /// Get a reference to the mask field + /// + /// # Panic + /// + /// This may panic if the underlying buffer does not represent a + /// valid sum2 message. If `self.check_buffer_length()` returned + /// `Ok(())` this method is guaranteed not to panic. + pub fn mask(&self) -> &'a [u8] { + &self.inner.as_ref()[SUM_SIGNATURE_RANGE.end..] + } +} + +/// High level representation of a sum2 message. These messages are +/// sent by sum participants during the sum2 phase. +/// +/// # Examples +/// +/// ## Decoding a message +/// +/// ```rust +/// # use xain_fl::{crypto::ByteObject, message::{FromBytes, Sum2Owned}, ParticipantTaskSignature, mask::Mask}; +/// let signature = vec![0x11; 64]; +/// let mask = vec![ +/// 0x00, 0x00, 0x00, 0x08, // Length 8 +/// 0x00, 0x01, 0x02, 0x03, // Value: 0, 1, 2, 3 +/// ]; +/// let bytes = [signature.as_slice(), mask.as_slice()].concat(); +/// let parsed = Sum2Owned::from_bytes(&bytes).unwrap(); +/// let expected = Sum2Owned { +/// sum_signature: ParticipantTaskSignature::from_slice(&bytes[..64]).unwrap(), +/// mask: Mask::from(&[0, 1, 2, 3][..]), +/// }; +/// assert_eq!(parsed, expected); +/// ``` +/// +/// ## Encoding a message +/// +/// ```rust +/// # use xain_fl::{crypto::ByteObject, message::{ToBytes, Sum2Owned}, ParticipantTaskSignature, mask::Mask}; +/// let signature = vec![0x11; 64]; +/// let mask = vec![ +/// 0x00, 0x00, 0x00, 0x08, // Length 8 +/// 0x00, 0x01, 0x02, 0x03, // Value: 0, 1, 2, 3 +/// ]; +/// let bytes = [signature.as_slice(), mask.as_slice()].concat(); +/// +/// let sum_signature = ParticipantTaskSignature::from_slice(&bytes[..64]).unwrap(); +/// let mask = Mask::from(&[0, 1, 2, 3][..]); +/// let sum2 = Sum2Owned { +/// sum_signature, +/// mask, +/// }; +/// // we need a 72 bytes long buffer to serialize that message +/// assert_eq!(sum2.buffer_length(), 72); +/// let mut buf = vec![0xff; 72]; +/// sum2.to_bytes(&mut buf); +/// assert_eq!(bytes, buf); +/// ``` +#[derive(Eq, PartialEq, Clone, Debug)] +pub struct Sum2 { + /// Signature of the word "sum", using the participant's secret + /// signing key. This is used by the coordinator to verify that + /// the participant has been selected to perform the sum task. + pub sum_signature: ParticipantTaskSignature, + + /// Mask computed by the participant. + pub mask: M, +} + +impl ToBytes for Sum2 +where + M: Borrow, +{ + fn buffer_length(&self) -> usize { + SUM_SIGNATURE_RANGE.end + self.mask.borrow().buffer_length() + } + + fn to_bytes>(&self, buffer: &mut T) { + let mut writer = Sum2Buffer::new_unchecked(buffer.as_mut()); + self.sum_signature.to_bytes(&mut writer.sum_signature_mut()); + self.mask.borrow().to_bytes(&mut writer.mask_mut()); + } +} + +/// Owned version of a [`Sum2`] +pub type Sum2Owned = Sum2; + +impl FromBytes for Sum2Owned { + fn from_bytes>(buffer: &T) -> Result { + let reader = Sum2Buffer::new(buffer.as_ref())?; + Ok(Self { + sum_signature: ParticipantTaskSignature::from_bytes(&reader.sum_signature()) + .context("invalid sum signature")?, + mask: Mask::from_bytes(&reader.mask()).context("invalid mask")?, + }) + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use crate::{crypto::ByteObject, mask::Mask}; + + fn signature_bytes() -> Vec { + vec![0x99; ParticipantTaskSignature::LENGTH] + } + + fn mask_bytes() -> Vec { + vec![ + 0x00, 0x00, 0x00, 0x08, // Length 8 + 0x00, 0x01, 0x02, 0x03, // Value: 0, 1, 2, 3 + ] + } + + fn sum2_bytes() -> Vec { + let mut bytes = signature_bytes(); + bytes.extend(mask_bytes()); + bytes + } + + fn sum2() -> Sum2Owned { + let sum_signature = ParticipantTaskSignature::from_slice(&signature_bytes()[..]).unwrap(); + let mask = Mask::from(&[0, 1, 2, 3][..]); + Sum2Owned { + sum_signature, + mask, + } + } + + #[test] + fn buffer_read() { + let bytes = sum2_bytes(); + let buffer = Sum2Buffer::new(&bytes).unwrap(); + assert_eq!(buffer.sum_signature(), &signature_bytes()[..]); + assert_eq!(buffer.mask(), &mask_bytes()[..]); + } + + #[test] + fn buffer_new_invalid() { + let mut bytes = sum2_bytes(); + assert!(Sum2Buffer::new(&bytes[1..]).is_err()); + // make the length field for the mask invalid + bytes[66] = 1; + assert!(Sum2Buffer::new(&bytes[..]).is_err()); + } + + #[test] + fn buffer_write() { + let mut bytes = vec![0xff; 72]; + { + let mut buffer = Sum2Buffer::new_unchecked(&mut bytes); + buffer + .sum_signature_mut() + .copy_from_slice(&signature_bytes()[..]); + buffer.mask_mut().copy_from_slice(&mask_bytes()[..]); + } + assert_eq!(&bytes[..], &sum2_bytes()[..]); + } + + #[test] + fn encode() { + let message = sum2(); + assert_eq!(message.buffer_length(), sum2_bytes().len()); + + let mut buf = vec![0xff; message.buffer_length()]; + message.to_bytes(&mut buf); + assert_eq!(buf, sum2_bytes()); + } + + #[test] + fn decode() { + let bytes = sum2_bytes(); + let parsed = Sum2Owned::from_bytes(&bytes).unwrap(); + let expected = sum2(); + assert_eq!(parsed, expected); + } +} diff --git a/rust/src/message/payload/update.rs b/rust/src/message/payload/update.rs new file mode 100644 index 000000000..87963fd26 --- /dev/null +++ b/rust/src/message/payload/update.rs @@ -0,0 +1,433 @@ +use crate::{ + mask::MaskedModel, + message::{utils::range, DecodeError, FromBytes, LengthValueBuffer, ToBytes}, + LocalSeedDict, + ParticipantTaskSignature, +}; +use anyhow::{anyhow, Context}; +use std::{borrow::Borrow, ops::Range}; + +const SUM_SIGNATURE_RANGE: Range = range(0, ParticipantTaskSignature::LENGTH); +const UPDATE_SIGNATURE_RANGE: Range = + range(SUM_SIGNATURE_RANGE.end, ParticipantTaskSignature::LENGTH); + +#[derive(Clone, Debug)] +/// Wrapper around a buffer that contains an update message. +/// +/// # Example +/// +/// ```rust +/// # use xain_fl::message::UpdateBuffer; +/// let sum_signature = vec![0x11; 64]; +/// let update_signature = vec![0x22; 64]; +/// let masked_model = vec![ +/// 0x00, 0x00, 0x00, 0x08, // Length 8 +/// 0x00, 0x01, 0x02, 0x03, // Value: 0, 1, 2, 3 +/// ]; +/// let mut local_seed_dict = vec![]; +/// // Length field: ((32 + 80) * 2) + 4 = 228 +/// local_seed_dict.extend(vec![0x00, 0x00, 0x00, 0xe4]); +/// // first entry: a key (32 bytes), and an encrypted mask seed (80 bytes) +/// local_seed_dict.extend(vec![0x33; 32]); +/// local_seed_dict.extend(vec![0x44; 80]); +/// // second entry +/// local_seed_dict.extend(vec![0x55; 32]); +/// local_seed_dict.extend(vec![0x66; 80]); +/// +/// let bytes = vec![ +/// sum_signature.as_slice(), +/// update_signature.as_slice(), +/// masked_model.as_slice(), +/// local_seed_dict.as_slice(), +/// ] +/// .concat(); +/// +/// let buffer = UpdateBuffer::new(&bytes).unwrap(); +/// assert_eq!(buffer.sum_signature(), sum_signature.as_slice()); +/// assert_eq!(buffer.update_signature(), update_signature.as_slice()); +/// assert_eq!(buffer.masked_model().bytes(), masked_model.as_slice()); +/// assert_eq!(buffer.local_seed_dict().bytes(), local_seed_dict.as_slice()); +/// +/// +/// // This part shows how to write the same message. We'll write into this writer. +/// let mut writer = vec![0xff; 364]; +/// // our buffer contains invalid data, so we need to call UpdateBuffer::new_unchcked. UpdateBuffer::new would return an error. That means we need to be careful when accessing the fields. To avoid panics, we'll see the fields in order. +/// assert!(UpdateBuffer::new(&mut writer).is_err()); +/// let mut buffer = UpdateBuffer::new_unchecked(&mut writer); +/// +/// buffer +/// .sum_signature_mut() +/// .copy_from_slice(sum_signature.as_slice()); +/// buffer +/// .update_signature_mut() +/// .copy_from_slice(update_signature.as_slice()); +/// // it is important to set the length first, otherwise, "value_mut()" would panic +/// buffer +/// .masked_model_mut() +/// .set_length(8); +/// buffer +/// .masked_model_mut() +/// .value_mut() +/// .copy_from_slice(&masked_model[4..]); +/// buffer +/// .local_seed_dict_mut() +/// .set_length(228); +/// buffer +/// .local_seed_dict_mut() +/// .value_mut() +/// .copy_from_slice(&local_seed_dict[4..]); +/// assert_eq!(writer, bytes); +/// ``` +pub struct UpdateBuffer { + inner: T, +} + +impl> UpdateBuffer { + /// Perform bound checks on `bytes` to ensure its fields can be + /// accessed without panicking, and return an `UpdateBuffer`. + pub fn new(bytes: T) -> Result { + let buffer = Self { inner: bytes }; + buffer + .check_buffer_length() + .context("invalid UpdateBuffer")?; + Ok(buffer) + } + + /// Return an `UpdateBuffer` without performing any bound + /// check. This means accessing the various fields may panic if + /// the data is invalid. + pub fn new_unchecked(bytes: T) -> Self { + Self { inner: bytes } + } + + /// Perform bound checks to ensure the fields can be accessed + /// without panicking. + pub fn check_buffer_length(&self) -> Result<(), DecodeError> { + let len = self.inner.as_ref().len(); + // First, check the fixed size portion of the + // header. UPDATE_SIGNATURE_RANGE is the last field + if len < UPDATE_SIGNATURE_RANGE.end { + return Err(anyhow!( + "invalid buffer length: {} < {}", + len, + UPDATE_SIGNATURE_RANGE.end + )); + } + + // Check the length of the length of the masked model field + let _ = LengthValueBuffer::new(&self.inner.as_ref()[self.masked_model_offset()..]) + .context("invalid masked model field length")?; + + // Check the length of the local seed dictionary field + let _ = LengthValueBuffer::new(&self.inner.as_ref()[self.local_seed_dict_offset()..]) + .context("invalid local seed dictionary length")?; + + Ok(()) + } + + /// Get the offset of the masked model field + fn masked_model_offset(&self) -> usize { + UPDATE_SIGNATURE_RANGE.end + } + + /// Get the offset of the local seed dictionary field + fn local_seed_dict_offset(&self) -> usize { + let masked_model = + LengthValueBuffer::new_unchecked(&self.inner.as_ref()[self.masked_model_offset()..]); + self.masked_model_offset() + masked_model.length() as usize + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> UpdateBuffer<&'a T> { + /// Get the sum signature field + /// + /// # Panic + /// + /// This may panic if the underlying buffer does not represent a + /// valid update. If `self.check_buffer_length()` returned + /// `Ok(())` this method is guaranteed not to panic. + pub fn sum_signature(&self) -> &'a [u8] { + &self.inner.as_ref()[SUM_SIGNATURE_RANGE] + } + + /// Get the update signature field + /// + /// # Panic + /// + /// This may panic if the underlying buffer does not represent a + /// valid update. If `self.check_buffer_length()` returned + /// `Ok(())` this method is guaranteed not to panic. + pub fn update_signature(&self) -> &'a [u8] { + &self.inner.as_ref()[UPDATE_SIGNATURE_RANGE] + } + + /// Get the masked model field + /// + /// # Panic + /// + /// This may panic if the underlying buffer does not represent a + /// valid update. If `self.check_buffer_length()` returned + /// `Ok(())` this method is guaranteed not to panic. + pub fn masked_model(&self) -> LengthValueBuffer<&'a [u8]> { + let offset = self.masked_model_offset(); + LengthValueBuffer::new_unchecked(&self.inner.as_ref()[offset..]) + } + + /// Get the local seed dictionary field + /// + /// # Panic + /// + /// This may panic if the underlying buffer does not represent a + /// valid update. If `self.check_buffer_length()` returned + /// `Ok(())` this method is guaranteed not to panic. + pub fn local_seed_dict(&self) -> LengthValueBuffer<&'a [u8]> { + let offset = self.local_seed_dict_offset(); + LengthValueBuffer::new_unchecked(&self.inner.as_ref()[offset..]) + } +} + +impl + AsMut<[u8]>> UpdateBuffer { + /// Get a mutable reference to the sum signature field + /// + /// # Panic + /// + /// This may panic if the underlying buffer does not represent a + /// valid update. If `self.check_buffer_length()` returned + /// `Ok(())` this method is guaranteed not to panic. + pub fn sum_signature_mut(&mut self) -> &mut [u8] { + &mut self.inner.as_mut()[SUM_SIGNATURE_RANGE] + } + + /// Get a mutable reference to the update signature field + /// + /// # Panic + /// + /// This may panic if the underlying buffer does not represent a + /// valid update. If `self.check_buffer_length()` returned + /// `Ok(())` this method is guaranteed not to panic. + pub fn update_signature_mut(&mut self) -> &mut [u8] { + &mut self.inner.as_mut()[UPDATE_SIGNATURE_RANGE] + } + + /// Get a mutable reference to the masked model field + /// + /// # Panic + /// + /// This may panic if the underlying buffer does not represent a + /// valid update. If `self.check_buffer_length()` returned + /// `Ok(())` this method is guaranteed not to panic. + pub fn masked_model_mut(&mut self) -> LengthValueBuffer<&mut [u8]> { + let offset = UPDATE_SIGNATURE_RANGE.end; + LengthValueBuffer::new_unchecked(&mut self.inner.as_mut()[offset..]) + } + + /// Get a mutable reference to the local seed dictionary field + /// + /// # Panic + /// + /// This may panic if the underlying buffer does not represent a + /// valid update. If `self.check_buffer_length()` returned + /// `Ok(())` this method is guaranteed not to panic. + pub fn local_seed_dict_mut(&mut self) -> LengthValueBuffer<&mut [u8]> { + let offset = UPDATE_SIGNATURE_RANGE.end + self.masked_model_mut().length() as usize; + LengthValueBuffer::new_unchecked(&mut self.inner.as_mut()[offset..]) + } +} + +#[derive(Debug, Eq, PartialEq, Clone)] +/// High level representation of an update message. These messages are +/// sent by update partipants during the update phase. +pub struct Update { + /// Signature of the round seed and the word "sum", used to + /// determine whether a participant is selected for the sum task + pub sum_signature: ParticipantTaskSignature, + /// Signature of the round seed and the word "update", used to + /// determine whether a participant is selected for the update + /// task + pub update_signature: ParticipantTaskSignature, + /// Model trained by an update participant, masked with randomness + /// derived from the participant seed + pub masked_model: M, + /// A dictionary that contains the seed used to mask + /// `masked_model`, encrypted with the ephemeral public key of + /// each sum participant + pub local_seed_dict: D, +} + +impl ToBytes for Update +where + D: Borrow, + M: Borrow, +{ + fn buffer_length(&self) -> usize { + UPDATE_SIGNATURE_RANGE.end + + self.masked_model.borrow().buffer_length() + + self.local_seed_dict.borrow().buffer_length() + } + + fn to_bytes>(&self, buffer: &mut T) { + let mut writer = UpdateBuffer::new_unchecked(buffer.as_mut()); + self.sum_signature.to_bytes(&mut writer.sum_signature_mut()); + self.update_signature + .to_bytes(&mut writer.update_signature_mut()); + self.masked_model + .borrow() + .to_bytes(&mut writer.masked_model_mut().bytes_mut()); + self.local_seed_dict + .borrow() + .to_bytes(&mut writer.local_seed_dict_mut().bytes_mut()); + } +} + +/// Owned version of a [`Update`] +pub type UpdateOwned = Update; + +impl FromBytes for UpdateOwned { + fn from_bytes>(buffer: &T) -> Result { + let reader = UpdateBuffer::new(buffer.as_ref())?; + Ok(Self { + sum_signature: ParticipantTaskSignature::from_bytes(&reader.sum_signature()) + .context("invalid sum signature")?, + update_signature: ParticipantTaskSignature::from_bytes(&reader.update_signature()) + .context("invalid update signature")?, + masked_model: MaskedModel::from_bytes(&reader.masked_model().bytes()) + .context("invalid masked model")?, + local_seed_dict: LocalSeedDict::from_bytes(&reader.local_seed_dict().bytes()) + .context("invalid local seed dictionary")?, + }) + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use crate::{crypto::ByteObject, mask::MaskedModel, EncrMaskSeed, SumParticipantPublicKey}; + use std::convert::TryFrom; + + fn sum_signature_bytes() -> Vec { + vec![0x33; 64] + } + + fn update_signature_bytes() -> Vec { + vec![0x44; 64] + } + + fn masked_model_bytes() -> Vec { + vec![ + 0x00, 0x00, 0x00, 0x08, // Length 8 + 0x00, 0x01, 0x02, 0x03, // Value: 0, 1, 2, 3 + ] + } + + fn local_seed_dict_bytes() -> Vec { + let mut bytes = vec![]; + + // Length (32+80) * 2 + 4 = 228 + bytes.extend(vec![0x00, 0x00, 0x00, 0xe4]); + + bytes.extend(vec![0x55; SumParticipantPublicKey::LENGTH]); // sum participant pk + bytes.extend(vec![0x66; EncrMaskSeed::BYTES]); // ephemeral pk + + // Second entry + bytes.extend(vec![0x77; SumParticipantPublicKey::LENGTH]); // sum participant pk + bytes.extend(vec![0x88; EncrMaskSeed::BYTES]); // ephemeral pk + + bytes + } + + fn update_bytes() -> Vec { + let mut bytes = sum_signature_bytes(); + bytes.extend(update_signature_bytes()); + bytes.extend(masked_model_bytes()); + bytes.extend(local_seed_dict_bytes()); + bytes + } + + fn update() -> UpdateOwned { + let sum_signature = + ParticipantTaskSignature::from_slice(&sum_signature_bytes()[..]).unwrap(); + let update_signature = + ParticipantTaskSignature::from_slice(&update_signature_bytes()[..]).unwrap(); + let masked_model = MaskedModel::from(&[0, 1, 2, 3][..]); + + let mut local_seed_dict = LocalSeedDict::new(); + + local_seed_dict.insert( + SumParticipantPublicKey::from_slice(vec![0x55; 32].as_slice()).unwrap(), + EncrMaskSeed::try_from(vec![0x66; EncrMaskSeed::BYTES]).unwrap(), + ); + local_seed_dict.insert( + SumParticipantPublicKey::from_slice(vec![0x77; 32].as_slice()).unwrap(), + EncrMaskSeed::try_from(vec![0x88; EncrMaskSeed::BYTES]).unwrap(), + ); + + UpdateOwned { + sum_signature, + update_signature, + masked_model, + local_seed_dict, + } + } + + #[test] + fn buffer_read() { + let bytes = update_bytes(); + let buffer = UpdateBuffer::new(&bytes).unwrap(); + assert_eq!(buffer.sum_signature(), sum_signature_bytes().as_slice()); + assert_eq!( + buffer.update_signature(), + update_signature_bytes().as_slice() + ); + assert_eq!(buffer.masked_model().bytes(), &masked_model_bytes()[..]); + assert_eq!( + buffer.local_seed_dict().bytes(), + &local_seed_dict_bytes()[..] + ); + } + + #[test] + fn decode_invalid_seed_dict() { + let mut invalid = local_seed_dict_bytes(); + // This truncates the last entry of the seed dictionary + invalid[3] = 0xe3; + let mut bytes = vec![]; + bytes.extend(sum_signature_bytes()); + bytes.extend(update_signature_bytes()); + bytes.extend(masked_model_bytes()); + bytes.extend(invalid); + + let e = UpdateOwned::from_bytes(&bytes).unwrap_err(); + let cause = e.source().unwrap().to_string(); + assert_eq!( + cause, + "invalid local seed dictionary: trailing bytes".to_string() + ); + } + + #[test] + fn decode() { + let bytes = update_bytes(); + let parsed = UpdateOwned::from_bytes(&bytes).unwrap(); + assert_eq!(parsed, update()); + } + + #[test] + fn encode() { + let expected = update_bytes(); + let update = update(); + assert_eq!(update.buffer_length(), expected.len()); + let mut buf = vec![0xff; update.buffer_length()]; + update.to_bytes(&mut buf); + // The order in which the hashmap is serialized is not + // guaranteed, but we chose our key/values such that they are + // sorted. + // + // First compute the offset at which the local seed dict value + // starts: two signature (64 bytes), the masked model (8 + // bytes), the length field (4 bytes) + let offset = 64 * 2 + 8 + 4; + // Sort the end of the buffer + (&mut buf[offset..]).sort(); + assert_eq!(buf, expected); + } +} diff --git a/rust/src/message/sum.rs b/rust/src/message/sum.rs deleted file mode 100644 index 88c476528..000000000 --- a/rust/src/message/sum.rs +++ /dev/null @@ -1,460 +0,0 @@ -use std::{ - borrow::Borrow, - convert::{TryFrom, TryInto}, - ops::Range, -}; - -use super::{MessageBuffer, Tag, LEN_BYTES, PK_BYTES}; - -use crate::{ - certificate::Certificate, - crypto::{ByteObject, Signature}, - CoordinatorPublicKey, - CoordinatorSecretKey, - ParticipantTaskSignature, - PetError, - PublicSigningKey, - SumParticipantEphemeralPublicKey, - SumParticipantPublicKey, - SumParticipantSecretKey, -}; - -#[derive(Clone, Debug)] -/// Access to sum message buffer fields. -struct SumMessageBuffer { - bytes: B, - certificate_range: Range, -} - -impl SumMessageBuffer> { - /// Create an empty sum message buffer. - fn new(certificate_len: usize) -> Self { - let bytes = [ - vec![0_u8; Self::EPHM_PK_RANGE.end], - certificate_len.to_le_bytes().to_vec(), - vec![0_u8; certificate_len], - ] - .concat(); - let certificate_range = - Self::CERTIFICATE_LEN_RANGE.end..Self::CERTIFICATE_LEN_RANGE.end + certificate_len; - Self { - bytes, - certificate_range, - } - } -} - -impl TryFrom> for SumMessageBuffer> { - type Error = PetError; - - /// Create a sum message buffer from `bytes`. Fails if the length of the input is invalid. - fn try_from(bytes: Vec) -> Result { - let mut buffer = Self { - bytes, - certificate_range: 0..0, - }; - if buffer.len() >= Self::CERTIFICATE_LEN_RANGE.end { - // safe unwrap: length of slice is guaranteed by constants - buffer.certificate_range = Self::CERTIFICATE_LEN_RANGE.end - ..Self::CERTIFICATE_LEN_RANGE.end - + usize::from_le_bytes(buffer.certificate_len().try_into().unwrap()); - } else { - return Err(PetError::InvalidMessage); - } - if buffer.len() == buffer.certificate_range.end { - Ok(buffer) - } else { - Err(PetError::InvalidMessage) - } - } -} - -impl + AsMut<[u8]>> MessageBuffer for SumMessageBuffer { - /// Get a reference to the message buffer. - fn bytes(&'_ self) -> &'_ [u8] { - self.bytes.as_ref() - } - - /// Get a mutable reference to the message buffer. - fn bytes_mut(&mut self) -> &mut [u8] { - self.bytes.as_mut() - } -} - -impl + AsMut<[u8]>> SumMessageBuffer { - /// Get the range of the public ephemeral key field. - const EPHM_PK_RANGE: Range = - Self::SUM_SIGNATURE_RANGE.end..Self::SUM_SIGNATURE_RANGE.end + PK_BYTES; - - /// Get the range of the certificate length field. - const CERTIFICATE_LEN_RANGE: Range = - Self::EPHM_PK_RANGE.end..Self::EPHM_PK_RANGE.end + LEN_BYTES; - - /// Get a reference to the public ephemeral key field. - fn ephm_pk(&'_ self) -> &'_ [u8] { - &self.bytes()[Self::EPHM_PK_RANGE] - } - - /// Get a mutable reference to the public ephemeral key field. - fn ephm_pk_mut(&mut self) -> &mut [u8] { - &mut self.bytes_mut()[Self::EPHM_PK_RANGE] - } - - /// Get a reference to the certificate length field. - fn certificate_len(&'_ self) -> &'_ [u8] { - &self.bytes()[Self::CERTIFICATE_LEN_RANGE] - } - - /// Get a reference to the certificate field. - fn certificate(&'_ self) -> &'_ [u8] { - &self.bytes()[self.certificate_range.clone()] - } - - /// Get a mutable reference to the certificate field. - fn certificate_mut(&mut self) -> &mut [u8] { - let range = self.certificate_range.clone(); - &mut self.bytes_mut()[range] - } -} - -#[derive(Clone, Debug, PartialEq)] -/// Encryption and decryption of sum messages. -pub struct SumMessage -where - K: Borrow, - S: Borrow, - E: Borrow, - C: Borrow, -{ - pk: K, - sum_signature: S, - ephm_pk: E, - certificate: C, -} - -impl SumMessage -where - K: Borrow, - S: Borrow, - E: Borrow, - C: Borrow, -{ - /// Create a sum message from its parts. - pub fn from_parts(pk: K, sum_signature: S, ephm_pk: E, certificate: C) -> Self { - Self { - pk, - sum_signature, - ephm_pk, - certificate, - } - } - - /// Serialize the sum message into a buffer. - fn serialize + AsMut<[u8]>>( - &self, - buffer: &mut SumMessageBuffer, - pk: &CoordinatorPublicKey, - ) { - buffer.tag_mut().copy_from_slice([Tag::Sum as u8].as_ref()); - buffer.coord_pk_mut().copy_from_slice(pk.as_slice()); - buffer - .part_pk_mut() - .copy_from_slice(self.pk.borrow().as_slice()); - buffer - .sum_signature_mut() - .copy_from_slice(self.sum_signature.borrow().as_slice()); - buffer - .ephm_pk_mut() - .copy_from_slice(self.ephm_pk.borrow().as_slice()); - buffer - .certificate_mut() - .copy_from_slice(self.certificate.borrow().as_ref()); - } - - /// Sign and encrypt the sum message. - pub fn seal(&self, sk: &SumParticipantSecretKey, pk: &CoordinatorPublicKey) -> Vec { - let mut buffer = SumMessageBuffer::new(self.certificate.borrow().len()); - self.serialize(&mut buffer, pk); - let signature = sk.sign_detached(buffer.message()); - buffer.signature_mut().copy_from_slice(signature.as_slice()); - pk.encrypt(buffer.bytes()) - } -} - -impl - SumMessage< - SumParticipantPublicKey, - ParticipantTaskSignature, - SumParticipantEphemeralPublicKey, - Certificate, - > -{ - /// Deserialize a sum message from a buffer. Fails if the length of a part is invalid. - fn deserialize(buffer: SumMessageBuffer>) -> Result { - let pk = SumParticipantPublicKey::from_slice(buffer.part_pk()) - .ok_or(PetError::InvalidMessage)?; - let sum_signature = - Signature::from_slice(buffer.sum_signature()).ok_or(PetError::InvalidMessage)?; - let ephm_pk = SumParticipantEphemeralPublicKey::from_slice(buffer.ephm_pk()) - .ok_or(PetError::InvalidMessage)?; - let certificate = Certificate::deserialize(buffer.certificate())?; - Ok(Self { - pk, - sum_signature, - ephm_pk, - certificate, - }) - } - - /// Decrypt and verify a sum message. Fails if decryption or validation fails. - pub fn open( - bytes: &[u8], - pk: &CoordinatorPublicKey, - sk: &CoordinatorSecretKey, - ) -> Result { - let buffer = - SumMessageBuffer::try_from(sk.decrypt(bytes, pk).or(Err(PetError::InvalidMessage))?)?; - if buffer.tag() == [Tag::Sum as u8] - && buffer.coord_pk() == pk.as_slice() - && PublicSigningKey::from_slice(buffer.part_pk()) - .ok_or(PetError::InvalidMessage)? - .verify_detached( - &Signature::from_slice(buffer.signature()).ok_or(PetError::InvalidMessage)?, - buffer.message(), - ) - { - Ok(Self::deserialize(buffer)?) - } else { - Err(PetError::InvalidMessage) - } - } - - derive_struct_fields!( - pk, SumParticipantPublicKey; - sum_signature, ParticipantTaskSignature; - certificate, Certificate; - ephm_pk, SumParticipantEphemeralPublicKey; - ); -} - -#[cfg(test)] -mod tests { - use sodiumoxide::randombytes::randombytes; - - use super::*; - use crate::{ - crypto::{generate_encrypt_key_pair, generate_signing_key_pair}, - message::{SIGNATURE_BYTES, TAG_BYTES}, - }; - - fn auxiliary_bytes() -> Vec { - [ - randombytes(225), - (32 as usize).to_le_bytes().to_vec(), - vec![0_u8; 32], - ] - .concat() - } - - type MB = SumMessageBuffer>; - - #[test] - fn test_summessagebuffer_ranges() { - assert_eq!(MB::SIGNATURE_RANGE, ..SIGNATURE_BYTES); - assert_eq!(MB::MESSAGE_RANGE, SIGNATURE_BYTES..); - assert_eq!(MB::TAG_RANGE, 64..64 + TAG_BYTES); - assert_eq!(MB::COORD_PK_RANGE, 65..65 + PK_BYTES); - assert_eq!(MB::PART_PK_RANGE, 97..97 + PK_BYTES); - assert_eq!(MB::SUM_SIGNATURE_RANGE, 129..129 + SIGNATURE_BYTES); - assert_eq!(MB::EPHM_PK_RANGE, 193..193 + PK_BYTES); - assert_eq!(MB::CERTIFICATE_LEN_RANGE, 225..225 + LEN_BYTES); - assert_eq!( - SumMessageBuffer::new(32).certificate_range, - 225 + LEN_BYTES..225 + LEN_BYTES + 32, - ); - } - - #[test] - fn test_summessagebuffer_fields() { - // new - assert_eq!( - SumMessageBuffer::new(32).bytes, - [ - vec![0_u8; 225], - (32 as usize).to_le_bytes().to_vec(), - vec![0_u8; 32], - ] - .concat(), - ); - - // try from - let mut bytes = auxiliary_bytes(); - let mut buffer = SumMessageBuffer::try_from(bytes.clone()).unwrap(); - assert_eq!(buffer.bytes, bytes); - assert_eq!( - SumMessageBuffer::try_from(vec![0_u8; 0]).unwrap_err(), - PetError::InvalidMessage, - ); - - // length - assert_eq!(buffer.len(), 257 + LEN_BYTES); - - // signature - assert_eq!(buffer.signature(), &bytes[MB::SIGNATURE_RANGE]); - assert_eq!(buffer.signature_mut(), &mut bytes[MB::SIGNATURE_RANGE]); - - // message - assert_eq!(buffer.message(), &bytes[MB::MESSAGE_RANGE]); - - // tag - assert_eq!(buffer.tag(), &bytes[MB::TAG_RANGE]); - assert_eq!(buffer.tag_mut(), &mut bytes[MB::TAG_RANGE]); - - // coordinator pk - assert_eq!(buffer.coord_pk(), &bytes[MB::COORD_PK_RANGE]); - assert_eq!(buffer.coord_pk_mut(), &mut bytes[MB::COORD_PK_RANGE]); - - // participant pk - assert_eq!(buffer.part_pk(), &bytes[MB::PART_PK_RANGE]); - assert_eq!(buffer.part_pk_mut(), &mut bytes[MB::PART_PK_RANGE]); - - // sum signature - assert_eq!(buffer.sum_signature(), &bytes[MB::SUM_SIGNATURE_RANGE]); - assert_eq!( - buffer.sum_signature_mut(), - &mut bytes[MB::SUM_SIGNATURE_RANGE], - ); - - // ephm pk - assert_eq!(buffer.ephm_pk(), &bytes[MB::EPHM_PK_RANGE]); - assert_eq!(buffer.ephm_pk_mut(), &mut bytes[MB::EPHM_PK_RANGE]); - - // certificate - assert_eq!(buffer.certificate_len(), &bytes[MB::CERTIFICATE_LEN_RANGE]); - let range = buffer.certificate_range.clone(); - assert_eq!(buffer.certificate(), &bytes[range.clone()]); - assert_eq!(buffer.certificate_mut(), &mut bytes[range]); - } - - #[test] - fn test_summessage_serialize() { - // from parts - let pk = &SumParticipantPublicKey::from_slice_unchecked(randombytes(32).as_slice()); - let sum_signature = &Signature::from_slice_unchecked(randombytes(64).as_slice()); - let ephm_pk = - &SumParticipantEphemeralPublicKey::from_slice_unchecked(randombytes(32).as_slice()); - let certificate = &Certificate::zeroed(); - let msg = SumMessage::from_parts(pk, sum_signature, ephm_pk, certificate); - assert_eq!( - msg.pk as *const SumParticipantPublicKey, - pk as *const SumParticipantPublicKey, - ); - assert_eq!( - msg.sum_signature as *const Signature, - sum_signature as *const Signature, - ); - assert_eq!( - msg.ephm_pk as *const SumParticipantEphemeralPublicKey, - ephm_pk as *const SumParticipantEphemeralPublicKey, - ); - assert_eq!( - msg.certificate as *const Certificate, - certificate as *const Certificate, - ); - - // serialize - let mut buffer = SumMessageBuffer::new(32); - let coord_pk = CoordinatorPublicKey::from_slice_unchecked(randombytes(32).as_slice()); - msg.serialize(&mut buffer, &coord_pk); - assert_eq!(buffer.tag(), [Tag::Sum as u8].as_ref()); - assert_eq!(buffer.coord_pk(), coord_pk.as_slice()); - assert_eq!(buffer.part_pk(), pk.as_slice()); - assert_eq!(buffer.sum_signature(), sum_signature.as_slice()); - assert_eq!(buffer.ephm_pk(), ephm_pk.as_slice()); - assert_eq!( - buffer.certificate_len(), - certificate.len().to_le_bytes().as_ref(), - ); - assert_eq!(buffer.certificate(), certificate.as_slice()); - } - - #[test] - fn test_summessage_deserialize() { - // deserialize - let bytes = auxiliary_bytes(); - let buffer = SumMessageBuffer::try_from(bytes.clone()).unwrap(); - let msg = SumMessage::deserialize(buffer.clone()).unwrap(); - assert_eq!( - msg.pk(), - &SumParticipantPublicKey::from_slice_unchecked(&bytes[MB::PART_PK_RANGE]), - ); - assert_eq!( - msg.sum_signature(), - &Signature::from_slice_unchecked(&bytes[MB::SUM_SIGNATURE_RANGE]), - ); - assert_eq!( - msg.ephm_pk(), - &SumParticipantEphemeralPublicKey::from_slice_unchecked(&bytes[MB::EPHM_PK_RANGE]), - ); - assert_eq!( - msg.certificate(), - &Certificate::deserialize(&bytes[buffer.certificate_range.clone()]).unwrap(), - ); - } - - #[test] - fn test_summessage() { - // seal - let (pk, sk) = generate_signing_key_pair(); - let sum_signature = Signature::from_slice_unchecked(randombytes(64).as_slice()); - let ephm_pk = - SumParticipantEphemeralPublicKey::from_slice_unchecked(randombytes(32).as_slice()); - let certificate = Certificate::zeroed(); - let (coord_pk, coord_sk) = generate_encrypt_key_pair(); - let bytes = SumMessage::from_parts(&pk, &sum_signature, &ephm_pk, &certificate) - .seal(&sk, &coord_pk); - - // open - let msg = SumMessage::open(&bytes, &coord_pk, &coord_sk).unwrap(); - assert_eq!(msg.pk(), &pk); - assert_eq!(msg.sum_signature(), &sum_signature); - assert_eq!(msg.ephm_pk(), &ephm_pk); - assert_eq!(msg.certificate(), &certificate); - - // wrong signature - let bytes = auxiliary_bytes(); - let mut buffer = SumMessageBuffer::try_from(bytes).unwrap(); - let msg = SumMessage::from_parts(&pk, &sum_signature, &ephm_pk, &certificate); - msg.serialize(&mut buffer, &coord_pk); - let bytes = coord_pk.encrypt(buffer.bytes()); - assert_eq!( - SumMessage::open(&bytes, &coord_pk, &coord_sk).unwrap_err(), - PetError::InvalidMessage, - ); - - // wrong receiver - msg.serialize( - &mut buffer, - &CoordinatorPublicKey::from_slice_unchecked(randombytes(32).as_slice()), - ); - let bytes = coord_pk.encrypt(buffer.bytes()); - assert_eq!( - SumMessage::open(&bytes, &coord_pk, &coord_sk).unwrap_err(), - PetError::InvalidMessage, - ); - - // wrong tag - buffer.tag_mut().copy_from_slice([Tag::None as u8].as_ref()); - let bytes = coord_pk.encrypt(buffer.bytes()); - assert_eq!( - SumMessage::open(&bytes, &coord_pk, &coord_sk).unwrap_err(), - PetError::InvalidMessage, - ); - - // wrong length - assert_eq!( - SumMessage::open([0_u8; 0].as_ref(), &coord_pk, &coord_sk).unwrap_err(), - PetError::InvalidMessage, - ); - } -} diff --git a/rust/src/message/sum2.rs b/rust/src/message/sum2.rs deleted file mode 100644 index fd6989904..000000000 --- a/rust/src/message/sum2.rs +++ /dev/null @@ -1,489 +0,0 @@ -use std::{ - borrow::Borrow, - convert::{TryFrom, TryInto}, - ops::Range, -}; - -use super::{MessageBuffer, Tag, LEN_BYTES}; -use crate::{ - certificate::Certificate, - crypto::{ByteObject, Signature}, - mask::{Integers, Mask}, - CoordinatorPublicKey, - CoordinatorSecretKey, - ParticipantTaskSignature, - PetError, - SumParticipantPublicKey, - SumParticipantSecretKey, -}; - -#[derive(Clone, Debug)] -/// Access to sum2 message buffer fields. -struct Sum2MessageBuffer { - bytes: B, - certificate_range: Range, - mask_range: Range, -} - -impl Sum2MessageBuffer> { - /// Create an empty sum2 message buffer. - fn new(certificate_len: usize, mask_len: usize) -> Self { - let bytes = [ - vec![0_u8; Self::SUM_SIGNATURE_RANGE.end], - certificate_len.to_le_bytes().to_vec(), - mask_len.to_le_bytes().to_vec(), - vec![0_u8; certificate_len + mask_len], - ] - .concat(); - let certificate_range = - Self::MASK_LEN_RANGE.end..Self::MASK_LEN_RANGE.end + certificate_len; - let mask_range = certificate_range.end..certificate_range.end + mask_len; - Self { - bytes, - certificate_range, - mask_range, - } - } -} - -impl TryFrom> for Sum2MessageBuffer> { - type Error = PetError; - - /// Create a sum2 message buffer from `bytes`. Fails if the length of the input is invalid. - fn try_from(bytes: Vec) -> Result { - let mut buffer = Self { - bytes, - certificate_range: 0..0, - mask_range: 0..0, - }; - if buffer.len() >= Self::MASK_LEN_RANGE.end { - // safe unwraps: lengths of slices are guaranteed by constants - buffer.certificate_range = Self::MASK_LEN_RANGE.end - ..Self::MASK_LEN_RANGE.end - + usize::from_le_bytes(buffer.certificate_len().try_into().unwrap()); - buffer.mask_range = buffer.certificate_range.end - ..buffer.certificate_range.end - + usize::from_le_bytes(buffer.mask_len().try_into().unwrap()); - } else { - return Err(PetError::InvalidMessage); - } - if buffer.len() == buffer.mask_range.end { - Ok(buffer) - } else { - Err(PetError::InvalidMessage) - } - } -} - -impl + AsMut<[u8]>> MessageBuffer for Sum2MessageBuffer { - /// Get a reference to the message buffer. - fn bytes(&'_ self) -> &'_ [u8] { - self.bytes.as_ref() - } - - /// Get a mutable reference to the message buffer. - fn bytes_mut(&mut self) -> &mut [u8] { - self.bytes.as_mut() - } -} - -impl + AsMut<[u8]>> Sum2MessageBuffer { - /// Get the range of the certificate length field. - const CERTIFICATE_LEN_RANGE: Range = - Self::SUM_SIGNATURE_RANGE.end..Self::SUM_SIGNATURE_RANGE.end + LEN_BYTES; - - /// Get the range of the masked model length field. - const MASK_LEN_RANGE: Range = - Self::CERTIFICATE_LEN_RANGE.end..Self::CERTIFICATE_LEN_RANGE.end + LEN_BYTES; - - /// Get a reference to the certificate length field. - fn certificate_len(&'_ self) -> &'_ [u8] { - &self.bytes()[Self::CERTIFICATE_LEN_RANGE] - } - - /// Get a reference to the certificate field. - fn certificate(&'_ self) -> &'_ [u8] { - &self.bytes()[self.certificate_range.clone()] - } - - /// Get a mutable reference to the certificate field. - fn certificate_mut(&mut self) -> &mut [u8] { - let range = self.certificate_range.clone(); - &mut self.bytes_mut()[range] - } - - /// Get a reference to the mask length field. - fn mask_len(&'_ self) -> &'_ [u8] { - &self.bytes()[Self::MASK_LEN_RANGE] - } - - /// Get a reference to the mask field. - fn mask(&'_ self) -> &'_ [u8] { - &self.bytes()[self.mask_range.clone()] - } - - /// Get a mutable reference to the mask field. - fn mask_mut(&mut self) -> &mut [u8] { - let range = self.mask_range.clone(); - &mut self.bytes_mut()[range] - } -} - -#[derive(Clone, Debug, PartialEq)] -/// Encryption and decryption of sum2 messages. -pub struct Sum2Message -where - K: Borrow, - S: Borrow, - C: Borrow, - M: Borrow, -{ - pk: K, - sum_signature: S, - certificate: C, - mask: M, -} - -impl Sum2Message -where - K: Borrow, - S: Borrow, - C: Borrow, - M: Borrow, -{ - /// Create a sum2 message from its parts. - pub fn from_parts(pk: K, sum_signature: S, certificate: C, mask: M) -> Self { - Self { - pk, - sum_signature, - certificate, - mask, - } - } - - /// Serialize the sum2 message into a buffer. - fn serialize + AsMut<[u8]>>( - &self, - buffer: &mut Sum2MessageBuffer, - pk: &CoordinatorPublicKey, - ) { - buffer.tag_mut().copy_from_slice([Tag::Sum2 as u8].as_ref()); - buffer - .coord_pk_mut() - .copy_from_slice(pk.borrow().as_slice()); - buffer - .part_pk_mut() - .copy_from_slice(self.pk.borrow().as_slice()); - buffer - .sum_signature_mut() - .copy_from_slice(self.sum_signature.borrow().as_slice()); - buffer - .certificate_mut() - .copy_from_slice(self.certificate.borrow().as_ref()); - buffer - .mask_mut() - .copy_from_slice(self.mask.borrow().serialize().as_slice()); - } - - /// Sign and encrypt the sum2message. - pub fn seal(&self, sk: &SumParticipantSecretKey, pk: &CoordinatorPublicKey) -> Vec { - let mut buffer = - Sum2MessageBuffer::new(self.certificate.borrow().len(), self.mask.borrow().len()); - self.serialize(&mut buffer, pk); - let signature = sk.sign_detached(buffer.message()); - buffer.signature_mut().copy_from_slice(signature.as_slice()); - pk.encrypt(buffer.bytes()) - } -} - -impl Sum2Message { - /// Deserialize a sum2 message from a buffer. Fails if the length of a part is invalid. - fn deserialize(buffer: Sum2MessageBuffer>) -> Result { - let pk = SumParticipantPublicKey::from_slice(buffer.part_pk()) - .ok_or(PetError::InvalidMessage)?; - let sum_signature = - Signature::from_slice(buffer.sum_signature()).ok_or(PetError::InvalidMessage)?; - let certificate = Certificate::deserialize(buffer.certificate())?; - let mask = Mask::deserialize(buffer.mask())?; - Ok(Self { - pk, - sum_signature, - certificate, - mask, - }) - } - - /// Decrypt and verify a sum2 message. Fails if decryption or validation fails. - pub fn open( - bytes: &[u8], - pk: &CoordinatorPublicKey, - sk: &CoordinatorSecretKey, - ) -> Result { - let buffer = - Sum2MessageBuffer::try_from(sk.decrypt(bytes, pk).or(Err(PetError::InvalidMessage))?)?; - if buffer.tag() == [Tag::Sum2 as u8] - && buffer.coord_pk() == pk.as_slice() - && SumParticipantPublicKey::from_slice(buffer.part_pk()) - .ok_or(PetError::InvalidMessage)? - .verify_detached( - &Signature::from_slice(buffer.signature()).ok_or(PetError::InvalidMessage)?, - buffer.message(), - ) - { - Ok(Self::deserialize(buffer)?) - } else { - Err(PetError::InvalidMessage) - } - } - - derive_struct_fields!( - pk, SumParticipantPublicKey; - sum_signature, ParticipantTaskSignature; - certificate, Certificate; - mask, Mask; - ); -} - -#[cfg(test)] -mod tests { - use sodiumoxide::randombytes::randombytes; - - use super::*; - use crate::{ - crypto::{generate_encrypt_key_pair, generate_signing_key_pair}, - mask::{ - config::{BoundType, DataType, GroupType, MaskConfigs, ModelType}, - seed::MaskSeed, - }, - message::{PK_BYTES, SIGNATURE_BYTES, TAG_BYTES}, - }; - - type MB = Sum2MessageBuffer>; - - fn auxiliary_bytes() -> Vec { - let mask = auxiliary_mask(); - [ - randombytes(193), - (32 as usize).to_le_bytes().to_vec(), - mask.len().to_le_bytes().to_vec(), - randombytes(32), - mask.serialize(), - ] - .concat() - } - - fn auxiliary_mask() -> Mask { - let config = MaskConfigs::from_parts( - GroupType::Prime, - DataType::F32, - BoundType::B0, - ModelType::M3, - ) - .config(); - MaskSeed::generate().derive_mask(10, &config) - } - - #[test] - fn test_sum2messagebuffer_ranges() { - assert_eq!(MB::SIGNATURE_RANGE, ..SIGNATURE_BYTES); - assert_eq!(MB::MESSAGE_RANGE, SIGNATURE_BYTES..); - assert_eq!(MB::TAG_RANGE, 64..64 + TAG_BYTES); - assert_eq!(MB::COORD_PK_RANGE, 65..65 + PK_BYTES); - assert_eq!(MB::PART_PK_RANGE, 97..97 + PK_BYTES); - assert_eq!(MB::SUM_SIGNATURE_RANGE, 129..129 + SIGNATURE_BYTES); - assert_eq!(MB::CERTIFICATE_LEN_RANGE, 193..193 + LEN_BYTES); - assert_eq!(MB::MASK_LEN_RANGE, 193 + LEN_BYTES..193 + 2 * LEN_BYTES); - let buffer = Sum2MessageBuffer::new(32, 32); - assert_eq!( - buffer.certificate_range, - 193 + 2 * LEN_BYTES..193 + 2 * LEN_BYTES + 32, - ); - assert_eq!( - buffer.mask_range, - 193 + 2 * LEN_BYTES + 32..193 + 2 * LEN_BYTES + 32 + 32, - ); - } - - #[test] - fn test_sum2messagebuffer_fields() { - // new - assert_eq!( - Sum2MessageBuffer::new(32, 32).bytes, - [ - vec![0_u8; 193], - (32 as usize).to_le_bytes().to_vec(), - (32 as usize).to_le_bytes().to_vec(), - vec![0_u8; 64], - ] - .concat(), - ); - - // try from - let mut bytes = auxiliary_bytes(); - let mut buffer = Sum2MessageBuffer::try_from(bytes.clone()).unwrap(); - assert_eq!(buffer.bytes, bytes); - assert_eq!( - Sum2MessageBuffer::try_from(vec![0_u8; 0]).unwrap_err(), - PetError::InvalidMessage, - ); - - // length - assert_eq!(buffer.len(), 289 + 2 * LEN_BYTES); - - // signature - assert_eq!(buffer.signature(), &bytes[MB::SIGNATURE_RANGE]); - assert_eq!(buffer.signature_mut(), &mut bytes[MB::SIGNATURE_RANGE]); - - // message - assert_eq!(buffer.message(), &bytes[MB::MESSAGE_RANGE]); - - // tag - assert_eq!(buffer.tag(), &bytes[MB::TAG_RANGE]); - assert_eq!(buffer.tag_mut(), &mut bytes[MB::TAG_RANGE]); - - // coordinator pk - assert_eq!(buffer.coord_pk(), &bytes[MB::COORD_PK_RANGE]); - assert_eq!(buffer.coord_pk_mut(), &mut bytes[MB::COORD_PK_RANGE]); - - // participant pk - assert_eq!(buffer.part_pk(), &bytes[MB::PART_PK_RANGE]); - assert_eq!(buffer.part_pk_mut(), &mut bytes[MB::PART_PK_RANGE]); - - // sum signature - assert_eq!(buffer.sum_signature(), &bytes[MB::SUM_SIGNATURE_RANGE]); - assert_eq!( - buffer.sum_signature_mut(), - &mut bytes[MB::SUM_SIGNATURE_RANGE], - ); - - // certificate - assert_eq!(buffer.certificate_len(), &bytes[MB::CERTIFICATE_LEN_RANGE]); - let range = buffer.certificate_range.clone(); - assert_eq!(buffer.certificate(), &bytes[range.clone()]); - assert_eq!(buffer.certificate_mut(), &mut bytes[range]); - - // mask - assert_eq!(buffer.mask_len(), &bytes[MB::MASK_LEN_RANGE]); - let range = buffer.mask_range.clone(); - assert_eq!(buffer.mask(), &bytes[range.clone()]); - assert_eq!(buffer.mask_mut(), &mut bytes[range]); - } - - #[test] - fn test_sum2message_serialize() { - // from parts - let pk = &SumParticipantPublicKey::from_slice_unchecked(randombytes(32).as_slice()); - let sum_signature = &Signature::from_slice_unchecked(randombytes(64).as_slice()); - let certificate = &Certificate::zeroed(); - let mask = &auxiliary_mask(); - let msg = Sum2Message::from_parts(pk, sum_signature, certificate, mask); - assert_eq!( - msg.pk as *const SumParticipantPublicKey, - pk as *const SumParticipantPublicKey, - ); - assert_eq!( - msg.sum_signature as *const Signature, - sum_signature as *const Signature, - ); - assert_eq!( - msg.certificate as *const Certificate, - certificate as *const Certificate, - ); - assert_eq!(msg.mask as *const Mask, mask as *const Mask); - - // serialize - let mut buffer = Sum2MessageBuffer::new(32, mask.len()); - let coord_pk = CoordinatorPublicKey::from_slice_unchecked(randombytes(32).as_slice()); - msg.serialize(&mut buffer, &coord_pk); - assert_eq!(buffer.tag(), [Tag::Sum2 as u8].as_ref()); - assert_eq!(buffer.coord_pk(), coord_pk.as_slice()); - assert_eq!(buffer.part_pk(), pk.as_slice()); - assert_eq!(buffer.sum_signature(), sum_signature.as_slice()); - assert_eq!( - buffer.certificate_len(), - certificate.len().to_le_bytes().as_ref(), - ); - assert_eq!(buffer.certificate(), certificate.as_slice()); - assert_eq!(buffer.mask_len(), mask.len().to_le_bytes().as_ref()); - assert_eq!(buffer.mask(), mask.serialize().as_slice()); - } - - #[test] - fn test_sum2message_deserialize() { - // deserialize - let bytes = auxiliary_bytes(); - let buffer = Sum2MessageBuffer::try_from(bytes.clone()).unwrap(); - let msg = Sum2Message::deserialize(buffer.clone()).unwrap(); - assert_eq!( - msg.pk(), - &SumParticipantPublicKey::from_slice_unchecked(&bytes[MB::PART_PK_RANGE]), - ); - assert_eq!( - msg.sum_signature(), - &Signature::from_slice_unchecked(&bytes[MB::SUM_SIGNATURE_RANGE]), - ); - assert_eq!( - msg.certificate(), - &Certificate::deserialize(&bytes[buffer.certificate_range.clone()]).unwrap(), - ); - assert_eq!( - msg.mask(), - &Mask::deserialize(&bytes[buffer.mask_range.clone()]).unwrap(), - ); - } - - #[test] - fn test_sum2message() { - // seal - let (pk, sk) = generate_signing_key_pair(); - let sum_signature = Signature::from_slice_unchecked(randombytes(64).as_slice()); - let certificate = Certificate::zeroed(); - let mask = auxiliary_mask(); - let (coord_pk, coord_sk) = generate_encrypt_key_pair(); - let bytes = - Sum2Message::from_parts(&pk, &sum_signature, &certificate, &mask).seal(&sk, &coord_pk); - - // open - let msg = Sum2Message::open(&bytes, &coord_pk, &coord_sk).unwrap(); - assert_eq!(msg.pk(), &pk); - assert_eq!(msg.sum_signature(), &sum_signature); - assert_eq!(msg.certificate(), &certificate); - assert_eq!(msg.mask(), &mask); - - // wrong signature - let bytes = auxiliary_bytes(); - let mut buffer = Sum2MessageBuffer::try_from(bytes).unwrap(); - let msg = Sum2Message::from_parts(&pk, &sum_signature, &certificate, &mask); - msg.serialize(&mut buffer, &coord_pk); - let bytes = coord_pk.encrypt(buffer.bytes()); - assert_eq!( - Sum2Message::open(&bytes, &coord_pk, &coord_sk).unwrap_err(), - PetError::InvalidMessage, - ); - - // wrong receiver - msg.serialize( - &mut buffer, - &CoordinatorPublicKey::from_slice_unchecked(randombytes(32).as_slice()), - ); - let bytes = coord_pk.encrypt(buffer.bytes()); - assert_eq!( - Sum2Message::open(&bytes, &coord_pk, &coord_sk).unwrap_err(), - PetError::InvalidMessage, - ); - - // wrong tag - buffer.tag_mut().copy_from_slice([Tag::None as u8].as_ref()); - let bytes = coord_pk.encrypt(buffer.bytes()); - assert_eq!( - Sum2Message::open(&bytes, &coord_pk, &coord_sk).unwrap_err(), - PetError::InvalidMessage, - ); - - // wrong length - assert_eq!( - Sum2Message::open([0_u8; 0].as_ref(), &coord_pk, &coord_sk).unwrap_err(), - PetError::InvalidMessage, - ); - } -} diff --git a/rust/src/message/traits.rs b/rust/src/message/traits.rs new file mode 100644 index 000000000..e5f314832 --- /dev/null +++ b/rust/src/message/traits.rs @@ -0,0 +1,405 @@ +use super::DecodeError; +use crate::{crypto::ByteObject, EncrMaskSeed, LocalSeedDict, SumParticipantPublicKey}; +use anyhow::{anyhow, Context}; +use std::{ + convert::{TryFrom, TryInto}, + io::{Cursor, Write}, + ops::Range, +}; + +/// `ToBytes` is implemented by types that can be serialized +pub trait ToBytes { + /// Length of the buffer for encoding the type + fn buffer_length(&self) -> usize; + /// Serialize the type in the given buffer. + /// + /// # Panic + /// + /// This method may panic if the given buffer is too small. Thus, + /// [`buffer_length`] must be called prior to calling `to_bytes`, + /// and a large enough buffer must be provided. + fn to_bytes>(&self, buffer: &mut T); +} + +/// `FromBytes` is the counterpart of `ToBytes`. It is implemented by +/// types that can be deserialized. +pub trait FromBytes: Sized { + /// Deserialize the type from the given buffer + fn from_bytes>(buffer: &T) -> Result; +} + +impl FromBytes for T +where + T: ByteObject, +{ + fn from_bytes>(buffer: &U) -> Result { + Self::from_slice(buffer.as_ref()).ok_or(anyhow!("failed to deserialize byte object")) + } +} + +impl ToBytes for T +where + T: ByteObject, +{ + fn buffer_length(&self) -> usize { + self.as_slice().len() + } + + fn to_bytes>(&self, buffer: &mut U) { + buffer.as_mut().copy_from_slice(self.as_slice()) + } +} + +/// Helper for encoding and decoding Length-Value (LV) fields. +/// +/// Note that the 4 bytes length field gives the length of the *total* +/// Length-Value field, _i.e._ the length of the value, plus the 4 +/// extra bytes of the length field itself. +/// +/// # Examples +/// +/// Decoding an LV field: +/// +/// ```rust +/// # use xain_fl::message::LengthValueBuffer; +/// let bytes = vec![ +/// 0x00, 0x00, 0x00, 0x05, // Length = 5 +/// 0xff, // Value = 0xff +/// 0x11, 0x22, // Extra bytes +/// ]; +/// let buffer = LengthValueBuffer::new(&bytes).unwrap(); +/// assert_eq!(buffer.length(), 5); +/// assert_eq!(buffer.value_length(), 1); +/// assert_eq!(buffer.value(), &[0xff][..]); +/// ``` +/// +/// Encoding an LV field: +/// +/// ```rust +/// # use xain_fl::message::LengthValueBuffer; +/// let mut bytes = vec![0xff; 9]; +/// let mut buffer = LengthValueBuffer::new_unchecked(&mut bytes); +/// // It is important to set the length field before setting the value, otherwise, `value_mut()` will panic. +/// buffer.set_length(8); +/// buffer.value_mut().copy_from_slice(&[0, 1, 2, 3][..]); +/// let expected = vec![ +/// 0x00, 0x00, 0x00, 0x08, // Length = 8 +/// 0x00, 0x01, 0x02, 0x03, // Value +/// 0xff, // unchanged +/// ]; +/// +/// assert_eq!(bytes, expected); +/// ``` +pub struct LengthValueBuffer { + inner: T, +} + +/// Size of the length field for encoding a Length-Value item. +const LENGTH_FIELD: Range = 0..4; + +impl> LengthValueBuffer { + /// Return a new `LengthValueBuffer`. This method performs bound + /// checks and returns an error if the given buffer is not a valid + /// Length-Value item. + /// + /// # Example + /// + /// ```rust + /// # use xain_fl::message::LengthValueBuffer; + /// // truncated length: + /// assert!(LengthValueBuffer::new(&vec![0x00, 0x00, 0x00]).is_err()); + /// + /// // truncated value: + /// let bytes = vec![ + /// 0x00, 0x00, 0x00, 0x08, // length: 8 + /// 0x11, 0x22, 0x33, // value + /// ]; + /// assert!(LengthValueBuffer::new(&bytes).is_err()); + /// + /// // valid Length-Value item + /// let bytes = vec![ + /// 0x00, 0x00, 0x00, 0x08, // length: 8 + /// 0x11, 0x22, 0x33, 0x44, // value + /// 0xaa, 0xbb, // extra bytes are ignored + /// ]; + /// let buf = LengthValueBuffer::new(&bytes).unwrap(); + /// assert_eq!(buf.length(), 8); + /// assert_eq!(buf.value(), &[0x11, 0x22, 0x33, 0x44][..]); + /// ``` + pub fn new(bytes: T) -> Result { + let buffer = Self { inner: bytes }; + buffer + .check_buffer_length() + .context("not a valid LengthValueBuffer")?; + Ok(buffer) + } + + /// Create a new `LengthValueBuffer` without any bound check. + pub fn new_unchecked(bytes: T) -> Self { + Self { inner: bytes } + } + + /// Check that the buffer is a valid Length-Value item. + pub fn check_buffer_length(&self) -> Result<(), DecodeError> { + let len = self.inner.as_ref().len(); + if len < LENGTH_FIELD.end { + return Err(anyhow!( + "invalid buffer length: {} < {}", + len, + LENGTH_FIELD.end + )); + } + + if (self.length() as usize) < LENGTH_FIELD.end { + return Err(anyhow!( + "invalid length value: {} (should be >= {})", + len, + LENGTH_FIELD.end + )); + } + + if len < self.length() as usize { + return Err(anyhow!( + "invalid buffer length: {} < {}", + len, + self.length(), + )); + } + Ok(()) + } + + /// Return the length field. Note that the value of the length + /// field includes the length of the field itself (4 bytes). + /// + /// # Panic + /// + /// This method may panic if buffer is not a valid Length-Value item. + pub fn length(&self) -> u32 { + // unwrap safe: the slice is exactly 4 bytes long + u32::from_be_bytes(self.inner.as_ref()[LENGTH_FIELD].try_into().unwrap()) + } + + /// Return the length of the value + pub fn value_length(&self) -> usize { + self.length() as usize - LENGTH_FIELD.end + } + + /// Return the range corresponding to the value + fn value_range(&self) -> Range { + let offset = LENGTH_FIELD.end; + let value_length = self.value_length(); + offset..offset + value_length + } +} + +impl> LengthValueBuffer { + /// Set the length field to the given value. + /// + /// # Panic + /// + /// This method may panic if buffer is not a valid Length-Value item. + pub fn set_length(&mut self, value: u32) { + self.inner.as_mut()[LENGTH_FIELD].copy_from_slice(&value.to_be_bytes()); + } +} + +// This impl that it differs from the usual mutable getters impl: +// +// 1. IT CONSUMES the buffer, so the LengthValueBuffer is a sort of +// "oneshot" buffer to safely access the payload of "length-value" +// field. We cannot take a reference to `self`, because the mutable +// slice the method returns would be bound to it, which means the +// LengthValueBuffer cannot be used within a function to return a +// mutable slice like so: +// +// fn my_length_value_field_mut(&self) -> &'a mut [u8] { +// // data is &'a +// let data = &mut self.inner.as_mut()[START..] +// // But the reference we're giving here is bound to &self +// let buf = LengthValueBuffer::new(&mut data).unwrap(); +// // this slice is bound to `buf`, which is a local variable... +// buf.value_mut() +// } +impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> LengthValueBuffer<&'a mut T> { + /// Get a mutable reference to the value field. + /// + /// # Panic + /// + /// This method may panic if buffer is not a valid Length-Value item. + pub fn value_mut(self) -> &'a mut [u8] { + let range = self.value_range(); + &mut self.inner.as_mut()[range] + } + + /// Get a mutable reference to the underlying buffer. + /// + /// # Panic + /// + /// This method may panic if buffer is not a valid Length-Value item. + pub fn bytes_mut(self) -> &'a mut [u8] { + &mut self.inner.as_mut()[..] + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> LengthValueBuffer<&'a T> { + /// Get a reference to the value field. + /// + /// # Panic + /// + /// This method may panic if buffer is not a valid Length-Value item. + pub fn value(&self) -> &'a [u8] { + &self.inner.as_ref()[self.value_range()] + } + + /// Get a reference to the underlying buffer. + /// + /// # Panic + /// + /// This method may panic if buffer is not a valid Length-Value item. + pub fn bytes(self) -> &'a [u8] { + let range = self.value_range(); + &self.inner.as_ref()[..range.end] + } +} + +macro_rules! impl_traits_for_length_value_types { + ($ty: ty) => { + impl ToBytes for $ty { + fn buffer_length(&self) -> usize { + LENGTH_FIELD.end + self.as_ref().len() + } + + fn to_bytes>(&self, buffer: &mut T) { + let mut writer = LengthValueBuffer::new_unchecked(buffer.as_mut()); + writer.set_length(self.buffer_length() as u32); + writer.value_mut().copy_from_slice(self.as_ref()); + } + } + + impl FromBytes for $ty { + fn from_bytes>(buffer: &T) -> Result { + let reader = LengthValueBuffer::new(buffer.as_ref())?; + Ok(Self::from(reader.value())) + } + } + }; +} + +impl_traits_for_length_value_types!(crate::certificate::Certificate); +impl_traits_for_length_value_types!(crate::mask::Mask); +impl_traits_for_length_value_types!(crate::mask::MaskedModel); + +const ENTRY_LENGTH: usize = SumParticipantPublicKey::LENGTH + EncrMaskSeed::BYTES; + +impl ToBytes for LocalSeedDict { + fn buffer_length(&self) -> usize { + LENGTH_FIELD.end + self.len() * ENTRY_LENGTH + } + + fn to_bytes>(&self, buffer: &mut T) { + let mut writer = Cursor::new(buffer.as_mut()); + let length = self.buffer_length() as u32; + writer.write(&length.to_be_bytes()).unwrap(); + for (key, value) in self { + writer.write(key.as_slice()).unwrap(); + writer.write(value.as_ref()).unwrap(); + } + } +} + +impl FromBytes for LocalSeedDict { + fn from_bytes>(buffer: &T) -> Result { + let reader = LengthValueBuffer::new(buffer.as_ref())?; + let nb_entries = reader.value_length() as usize / ENTRY_LENGTH; + let mut dict = LocalSeedDict::with_capacity(nb_entries); + + let key_length = SumParticipantPublicKey::LENGTH; + let mut entries = reader.value().chunks_exact(ENTRY_LENGTH); + for chunk in &mut entries { + // safe unwraps: lengths of slices are guaranteed + // by constants. + let key = SumParticipantPublicKey::from_slice(&chunk[..key_length]).unwrap(); + let value = EncrMaskSeed::try_from(&chunk[key_length..]).unwrap(); + if dict.insert(key, value).is_some() { + return Err(anyhow!("invalid local seed dictionary: duplicated key")); + } + } + if entries.remainder().len() > 0 { + return Err(anyhow!("invalid local seed dictionary: trailing bytes")); + } + Ok(dict) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn decode_length_value_buffer() { + let bytes = vec![ + 0x00, 0x00, 0x00, 0x05, // Length = 1 + 0xff, // Value = 0xff + 0x11, 0x22, // Extra bytes + ]; + let buffer = LengthValueBuffer::new(&bytes).unwrap(); + assert_eq!(buffer.length(), 5); + assert_eq!(buffer.value_length(), 1); + assert_eq!(buffer.value(), &[0xff][..]); + } + + #[test] + fn decode_empty_value() { + let bytes = vec![0x00, 0x00, 0x00, 0x04]; + let buffer = LengthValueBuffer::new(&bytes).unwrap(); + assert_eq!(buffer.length(), 4); + assert_eq!(buffer.value_length(), 0); + } + + #[test] + fn decode_length_value_buffer_buffer_exhausted() { + let bytes = vec![ + 0x00, 0x00, 0x00, 0x08, // Length = 6 + 0x11, 0x22, // Only 2 bytes + ]; + assert!(LengthValueBuffer::new(bytes).is_err()); + } + + #[test] + fn decode_length_value_buffer_invalid_length() { + // Missing bytes + let bytes = vec![0x00, 0x00, 0x00]; + assert!(LengthValueBuffer::new(bytes).is_err()); + // Length field invalid + let bytes = vec![0x00, 0x00, 0x00, 0x03]; + assert!(LengthValueBuffer::new(bytes).is_err()); + } + + #[test] + fn encode_length_value_buffer() { + let mut bytes = vec![0xff; 7]; + let mut buffer = LengthValueBuffer::new_unchecked(&mut bytes); + buffer.set_length(6); + buffer.value_mut().copy_from_slice(&[0x11, 0x22][..]); + let expected = vec![ + 0x00, 0x00, 0x00, 0x06, // Length = 6 + 0x11, 0x22, // Value + 0xff, // unchanged + ]; + + assert_eq!(bytes, expected); + } + + #[test] + fn encode_length_value_buffer_emty() { + let mut bytes = vec![0xff; 5]; + let mut buffer = LengthValueBuffer::new_unchecked(&mut bytes); + buffer.set_length(4); + buffer.value_mut().copy_from_slice(&[][..]); + let expected = vec![ + 0x00, 0x00, 0x00, 0x04, // Length = 0 + 0xff, // unchanged + ]; + + assert_eq!(bytes, expected); + } +} diff --git a/rust/src/message/update.rs b/rust/src/message/update.rs deleted file mode 100644 index a701c1252..000000000 --- a/rust/src/message/update.rs +++ /dev/null @@ -1,759 +0,0 @@ -use std::{ - borrow::Borrow, - convert::{TryFrom, TryInto}, - ops::Range, -}; - -use super::{MessageBuffer, Tag, LEN_BYTES, PK_BYTES, SIGNATURE_BYTES}; -use crate::{ - certificate::Certificate, - crypto::{ByteObject, Signature}, - mask::{seed::EncryptedMaskSeed, Integers, MaskedModel}, - CoordinatorPublicKey, - CoordinatorSecretKey, - LocalSeedDict, - ParticipantTaskSignature, - PetError, - SumParticipantPublicKey, - UpdateParticipantPublicKey, - UpdateParticipantSecretKey, -}; - -#[derive(Clone, Debug)] -/// Access to update message buffer fields. -struct UpdateMessageBuffer { - bytes: B, - certificate_range: Range, - masked_model_range: Range, - local_seed_dict_range: Range, -} - -impl UpdateMessageBuffer> { - /// Create an empty update message buffer. - fn new(certificate_len: usize, masked_model_len: usize, local_seed_dict_len: usize) -> Self { - let bytes = [ - vec![0_u8; Self::UPDATE_SIGNATURE_RANGE.end], - certificate_len.to_le_bytes().to_vec(), - masked_model_len.to_le_bytes().to_vec(), - local_seed_dict_len.to_le_bytes().to_vec(), - vec![0_u8; certificate_len + masked_model_len + local_seed_dict_len], - ] - .concat(); - let certificate_range = Self::LOCAL_SEED_DICT_LEN_RANGE.end - ..Self::LOCAL_SEED_DICT_LEN_RANGE.end + certificate_len; - let masked_model_range = certificate_range.end..certificate_range.end + masked_model_len; - let local_seed_dict_range = - masked_model_range.end..masked_model_range.end + local_seed_dict_len; - Self { - bytes, - certificate_range, - masked_model_range, - local_seed_dict_range, - } - } -} - -impl TryFrom> for UpdateMessageBuffer> { - type Error = PetError; - - /// Create an update message buffer from `bytes`. Fails if the length of the input is invalid. - fn try_from(bytes: Vec) -> Result { - let mut buffer = Self { - bytes, - certificate_range: 0..0, - masked_model_range: 0..0, - local_seed_dict_range: 0..0, - }; - if buffer.len() >= Self::LOCAL_SEED_DICT_LEN_RANGE.end { - // safe unwraps: lengths of slices are guaranteed by constants - buffer.certificate_range = Self::LOCAL_SEED_DICT_LEN_RANGE.end - ..Self::LOCAL_SEED_DICT_LEN_RANGE.end - + usize::from_le_bytes(buffer.certificate_len().try_into().unwrap()); - buffer.masked_model_range = buffer.certificate_range.end - ..buffer.certificate_range.end - + usize::from_le_bytes(buffer.masked_model_len().try_into().unwrap()); - buffer.local_seed_dict_range = buffer.masked_model_range.end - ..buffer.masked_model_range.end - + usize::from_le_bytes(buffer.local_seed_dict_len().try_into().unwrap()); - } else { - return Err(PetError::InvalidMessage); - } - if buffer.len() == buffer.local_seed_dict_range.end { - Ok(buffer) - } else { - Err(PetError::InvalidMessage) - } - } -} - -impl + AsMut<[u8]>> MessageBuffer for UpdateMessageBuffer { - /// Get a reference to the message buffer. - fn bytes(&'_ self) -> &'_ [u8] { - self.bytes.as_ref() - } - - /// Get a mutable reference to the message buffer. - fn bytes_mut(&mut self) -> &mut [u8] { - self.bytes.as_mut() - } -} - -impl + AsMut<[u8]>> UpdateMessageBuffer { - /// Get the range of the update signature field. - const UPDATE_SIGNATURE_RANGE: Range = - Self::SUM_SIGNATURE_RANGE.end..Self::SUM_SIGNATURE_RANGE.end + SIGNATURE_BYTES; - - /// Get the range of the certificate length field. - const CERTIFICATE_LEN_RANGE: Range = - Self::UPDATE_SIGNATURE_RANGE.end..Self::UPDATE_SIGNATURE_RANGE.end + LEN_BYTES; - - /// Get the range of the masked model length field. - const MASKED_MODEL_LEN_RANGE: Range = - Self::CERTIFICATE_LEN_RANGE.end..Self::CERTIFICATE_LEN_RANGE.end + LEN_BYTES; - - /// Get the range of the local seed dictionary length field. - const LOCAL_SEED_DICT_LEN_RANGE: Range = - Self::MASKED_MODEL_LEN_RANGE.end..Self::MASKED_MODEL_LEN_RANGE.end + LEN_BYTES; - - /// Get a reference to the update signature field. - fn update_signature(&'_ self) -> &'_ [u8] { - &self.bytes()[Self::UPDATE_SIGNATURE_RANGE] - } - - /// Get a mutable reference to the update signature field. - fn update_signature_mut(&mut self) -> &mut [u8] { - &mut self.bytes_mut()[Self::UPDATE_SIGNATURE_RANGE] - } - - /// Get a reference to the certificate length field. - fn certificate_len(&'_ self) -> &'_ [u8] { - &self.bytes()[Self::CERTIFICATE_LEN_RANGE] - } - - /// Get a reference to the certificate field. - fn certificate(&'_ self) -> &'_ [u8] { - &self.bytes()[self.certificate_range.clone()] - } - - /// Get a mutable reference to the certificate field. - fn certificate_mut(&mut self) -> &mut [u8] { - let range = self.certificate_range.clone(); - &mut self.bytes_mut()[range] - } - - /// Get a reference to the masked model length field. - fn masked_model_len(&'_ self) -> &'_ [u8] { - &self.bytes()[Self::MASKED_MODEL_LEN_RANGE] - } - - /// Get a reference to the masked model field. - fn masked_model(&'_ self) -> &'_ [u8] { - &self.bytes()[self.masked_model_range.clone()] - } - - /// Get a mutable reference to the masked model field. - fn masked_model_mut(&mut self) -> &mut [u8] { - let range = self.masked_model_range.clone(); - &mut self.bytes_mut()[range] - } - - /// Get a reference to the local seed dictionary length field. - fn local_seed_dict_len(&'_ self) -> &'_ [u8] { - &self.bytes()[Self::LOCAL_SEED_DICT_LEN_RANGE] - } - - /// Get a reference to the local seed dictionary field. - fn local_seed_dict(&'_ self) -> &'_ [u8] { - &self.bytes()[self.local_seed_dict_range.clone()] - } - - /// Get a mutable reference to the local seed dictionary field. - fn local_seed_dict_mut(&mut self) -> &mut [u8] { - let range = self.local_seed_dict_range.clone(); - &mut self.bytes_mut()[range] - } -} - -#[derive(Clone, Debug, PartialEq)] -/// Encryption and decryption of update messages. -pub struct UpdateMessage -where - K: Borrow, - S: Borrow, - C: Borrow, - M: Borrow, - D: Borrow, -{ - pk: K, - sum_signature: S, - update_signature: S, - certificate: C, - masked_model: M, - local_seed_dict: D, -} - -impl UpdateMessage -where - K: Borrow, - S: Borrow, - C: Borrow, - M: Borrow, - D: Borrow, -{ - /// Create an update message from its parts. - pub fn from_parts( - pk: K, - sum_signature: S, - update_signature: S, - certificate: C, - masked_model: M, - local_seed_dict: D, - ) -> Self { - Self { - pk, - sum_signature, - update_signature, - certificate, - masked_model, - local_seed_dict, - } - } - - /// Serialize the local seed dictionary into bytes. - fn serialize_local_seed_dict(&self) -> Vec { - self.local_seed_dict - .borrow() - .iter() - .flat_map(|(pk, seed)| [pk.as_slice(), seed.as_ref()].concat()) - .collect::>() - } - - /// Serialize the update message into a buffer. - fn serialize + AsMut<[u8]>>( - &self, - buffer: &mut UpdateMessageBuffer, - pk: &CoordinatorPublicKey, - ) { - buffer - .tag_mut() - .copy_from_slice([Tag::Update as u8].as_ref()); - buffer - .coord_pk_mut() - .copy_from_slice(pk.borrow().as_slice()); - buffer - .part_pk_mut() - .copy_from_slice(self.pk.borrow().as_slice()); - buffer - .sum_signature_mut() - .copy_from_slice(self.sum_signature.borrow().as_slice()); - buffer - .update_signature_mut() - .copy_from_slice(self.update_signature.borrow().as_slice()); - buffer - .certificate_mut() - .copy_from_slice(self.certificate.borrow().as_ref()); - buffer - .masked_model_mut() - .copy_from_slice(self.masked_model.borrow().serialize().as_slice()); - buffer - .local_seed_dict_mut() - .copy_from_slice(self.serialize_local_seed_dict().as_slice()); - } - - /// Sign and encrypt the update message. - pub fn seal(&self, sk: &UpdateParticipantSecretKey, pk: &CoordinatorPublicKey) -> Vec { - let mut buffer = UpdateMessageBuffer::new( - self.certificate.borrow().len(), - self.masked_model.borrow().len(), - self.local_seed_dict.borrow().len() * (PK_BYTES + EncryptedMaskSeed::BYTES), - ); - self.serialize(&mut buffer, pk); - let signature = sk.sign_detached(buffer.message()); - buffer.signature_mut().copy_from_slice(signature.as_slice()); - pk.encrypt(buffer.bytes()) - } -} - -impl - UpdateMessage< - UpdateParticipantPublicKey, - ParticipantTaskSignature, - Certificate, - MaskedModel, - LocalSeedDict, - > -{ - /// Deserialize a local seed dictionary from bytes. Fails if the length of the input is invalid. - fn deserialize_local_seed_dict(bytes: &[u8]) -> Result { - if bytes.len() % (PK_BYTES + EncryptedMaskSeed::BYTES) == 0 { - let local_seed_dict = bytes - .chunks_exact(PK_BYTES + EncryptedMaskSeed::BYTES) - .map(|chunk| { - if let (Some(pk), Some(seed)) = ( - SumParticipantPublicKey::from_slice(&chunk[..PK_BYTES]), - EncryptedMaskSeed::from_slice(&chunk[PK_BYTES..]), - ) { - Ok((pk, seed)) - } else { - Err(PetError::InvalidMessage) - } - }) - .collect::>()?; - Ok(local_seed_dict) - } else { - Err(PetError::InvalidMessage) - } - } - - /// Deserialize an update message from a buffer. Fails if the length of a part is invalid. - fn deserialize(buffer: UpdateMessageBuffer>) -> Result { - let pk = UpdateParticipantPublicKey::from_slice(buffer.part_pk()) - .ok_or(PetError::InvalidMessage)?; - let sum_signature = - Signature::from_slice(buffer.sum_signature()).ok_or(PetError::InvalidMessage)?; - let update_signature = - Signature::from_slice(buffer.update_signature()).ok_or(PetError::InvalidMessage)?; - let certificate = Certificate::deserialize(buffer.certificate())?; - let masked_model = MaskedModel::deserialize(buffer.masked_model())?; - let local_seed_dict = Self::deserialize_local_seed_dict(buffer.local_seed_dict())?; - Ok(Self { - pk, - sum_signature, - update_signature, - certificate, - masked_model, - local_seed_dict, - }) - } - - /// Decrypt and verify an update message. Fails if decryption or validation fails. - pub fn open( - bytes: &[u8], - pk: &CoordinatorPublicKey, - sk: &CoordinatorSecretKey, - ) -> Result { - let buffer = UpdateMessageBuffer::try_from( - sk.decrypt(bytes, pk).or(Err(PetError::InvalidMessage))?, - )?; - if buffer.tag() == [Tag::Update as u8] - && buffer.coord_pk() == pk.as_slice() - && UpdateParticipantPublicKey::from_slice(buffer.part_pk()) - .ok_or(PetError::InvalidMessage)? - .verify_detached( - &Signature::from_slice(buffer.signature()).ok_or(PetError::InvalidMessage)?, - buffer.message(), - ) - { - Ok(Self::deserialize(buffer)?) - } else { - Err(PetError::InvalidMessage) - } - } - - derive_struct_fields!( - pk, SumParticipantPublicKey; - sum_signature, ParticipantTaskSignature; - update_signature, ParticipantTaskSignature; - certificate, Certificate; - masked_model, MaskedModel; - local_seed_dict, LocalSeedDict; - ); -} - -#[cfg(test)] -mod tests { - use std::iter; - - use rand::SeedableRng; - use rand_chacha::ChaCha20Rng; - use sodiumoxide::randombytes::{randombytes, randombytes_uniform}; - - use super::*; - use crate::{ - crypto::{generate_encrypt_key_pair, generate_integer, generate_signing_key_pair}, - mask::{ - config::{BoundType, DataType, GroupType, MaskConfigs, ModelType}, - MaskedModel, - }, - message::TAG_BYTES, - }; - - type MB = UpdateMessageBuffer>; - - fn auxiliary_bytes(sum_dict_len: usize) -> Vec { - let masked_model = auxiliary_masked_model(); - [ - randombytes(257), - (32 as usize).to_le_bytes().to_vec(), - masked_model.len().to_le_bytes().to_vec(), - (112 * sum_dict_len as usize).to_le_bytes().to_vec(), - randombytes(32), - masked_model.serialize(), - randombytes(112 * sum_dict_len), - ] - .concat() - } - - fn auxiliary_masked_model() -> MaskedModel { - let mut prng = ChaCha20Rng::from_seed([0_u8; 32]); - let config = MaskConfigs::from_parts( - GroupType::Prime, - DataType::F32, - BoundType::B0, - ModelType::M3, - ) - .config(); - let integers = iter::repeat_with(|| generate_integer(&mut prng, config.order())) - .take(10) - .collect(); - MaskedModel::from_parts(integers, config).unwrap() - } - - #[test] - fn test_updatemessagebuffer_ranges() { - assert_eq!(MB::SIGNATURE_RANGE, ..SIGNATURE_BYTES); - assert_eq!(MB::MESSAGE_RANGE, SIGNATURE_BYTES..); - assert_eq!(MB::TAG_RANGE, 64..64 + TAG_BYTES); - assert_eq!(MB::COORD_PK_RANGE, 65..65 + PK_BYTES); - assert_eq!(MB::PART_PK_RANGE, 97..97 + PK_BYTES); - assert_eq!(MB::SUM_SIGNATURE_RANGE, 129..129 + SIGNATURE_BYTES); - assert_eq!(MB::UPDATE_SIGNATURE_RANGE, 193..193 + SIGNATURE_BYTES); - assert_eq!(MB::CERTIFICATE_LEN_RANGE, 257..257 + LEN_BYTES); - assert_eq!( - MB::MASKED_MODEL_LEN_RANGE, - 257 + LEN_BYTES..257 + 2 * LEN_BYTES, - ); - assert_eq!( - MB::LOCAL_SEED_DICT_LEN_RANGE, - 257 + 2 * LEN_BYTES..257 + 3 * LEN_BYTES, - ); - let sum_dict_len = 1 + randombytes_uniform(10) as usize; - let buffer = UpdateMessageBuffer::new(32, 32, 112 * sum_dict_len); - assert_eq!( - buffer.certificate_range, - 257 + 3 * LEN_BYTES..257 + 3 * LEN_BYTES + 32, - ); - assert_eq!( - buffer.masked_model_range, - 257 + 3 * LEN_BYTES + 32..257 + 3 * LEN_BYTES + 32 + 32, - ); - assert_eq!( - buffer.local_seed_dict_range, - 257 + 3 * LEN_BYTES + 32 + 32..257 + 3 * LEN_BYTES + 32 + 32 + 112 * sum_dict_len, - ); - } - - #[test] - fn test_updatemessagebuffer_fields() { - // new - let sum_dict_len = 1 + randombytes_uniform(10) as usize; - assert_eq!( - UpdateMessageBuffer::new(32, 32, 112 * sum_dict_len).bytes, - [ - vec![0_u8; 257], - (32 as usize).to_le_bytes().to_vec(), - (32 as usize).to_le_bytes().to_vec(), - (112 * sum_dict_len as usize).to_le_bytes().to_vec(), - vec![0_u8; 64 + 112 * sum_dict_len], - ] - .concat(), - ); - - // try from - let mut bytes = auxiliary_bytes(sum_dict_len); - let mut buffer = UpdateMessageBuffer::try_from(bytes.clone()).unwrap(); - assert_eq!(buffer.bytes, bytes); - assert_eq!( - UpdateMessageBuffer::try_from(vec![0_u8; 0]).unwrap_err(), - PetError::InvalidMessage, - ); - - // length - assert_eq!(buffer.len(), 353 + 112 * sum_dict_len + 3 * LEN_BYTES); - - // signature - assert_eq!(buffer.signature(), &bytes[MB::SIGNATURE_RANGE]); - assert_eq!(buffer.signature_mut(), &mut bytes[MB::SIGNATURE_RANGE]); - - // message - assert_eq!(buffer.message(), &bytes[MB::MESSAGE_RANGE]); - - // tag - assert_eq!(buffer.tag(), &bytes[MB::TAG_RANGE]); - assert_eq!(buffer.tag_mut(), &mut bytes[MB::TAG_RANGE]); - - // coordinator pk - assert_eq!(buffer.coord_pk(), &bytes[MB::COORD_PK_RANGE]); - assert_eq!(buffer.coord_pk_mut(), &mut bytes[MB::COORD_PK_RANGE]); - - // participant pk - assert_eq!(buffer.part_pk(), &bytes[MB::PART_PK_RANGE]); - assert_eq!(buffer.part_pk_mut(), &mut bytes[MB::PART_PK_RANGE]); - - // sum signature - assert_eq!(buffer.sum_signature(), &bytes[MB::SUM_SIGNATURE_RANGE]); - assert_eq!( - buffer.sum_signature_mut(), - &mut bytes[MB::SUM_SIGNATURE_RANGE], - ); - - // update signature - assert_eq!( - buffer.update_signature(), - &bytes[MB::UPDATE_SIGNATURE_RANGE], - ); - assert_eq!( - buffer.update_signature_mut(), - &mut bytes[MB::UPDATE_SIGNATURE_RANGE], - ); - - // certificate - assert_eq!(buffer.certificate_len(), &bytes[MB::CERTIFICATE_LEN_RANGE]); - let range = buffer.certificate_range.clone(); - assert_eq!(buffer.certificate(), &bytes[range.clone()]); - assert_eq!(buffer.certificate_mut(), &mut bytes[range]); - - // masked model - assert_eq!( - buffer.masked_model_len(), - &bytes[MB::MASKED_MODEL_LEN_RANGE], - ); - let range = buffer.masked_model_range.clone(); - assert_eq!(buffer.masked_model(), &bytes[range.clone()]); - assert_eq!(buffer.masked_model_mut(), &mut bytes[range]); - - // local seed dictionary - assert_eq!( - buffer.local_seed_dict_len(), - &bytes[MB::LOCAL_SEED_DICT_LEN_RANGE], - ); - let range = buffer.local_seed_dict_range.clone(); - assert_eq!(buffer.local_seed_dict(), &bytes[range.clone()]); - assert_eq!(buffer.local_seed_dict_mut(), &mut bytes[range]); - } - - #[test] - fn test_updatemessage_serialize() { - // from parts - let sum_dict_len = 1 + randombytes_uniform(10) as usize; - let pk = &UpdateParticipantPublicKey::from_slice_unchecked(randombytes(32).as_slice()); - let sum_signature = &Signature::from_slice_unchecked(randombytes(64).as_slice()); - let update_signature = &Signature::from_slice_unchecked(randombytes(64).as_slice()); - let certificate = &Certificate::zeroed(); - let masked_model = &auxiliary_masked_model(); - let local_seed_dict = &iter::repeat_with(|| { - ( - SumParticipantPublicKey::from_slice_unchecked(randombytes(32).as_slice()), - EncryptedMaskSeed::from_slice_unchecked(randombytes(80).as_slice()), - ) - }) - .take(sum_dict_len) - .collect(); - let msg = UpdateMessage::from_parts( - pk, - sum_signature, - update_signature, - certificate, - masked_model, - local_seed_dict, - ); - assert_eq!( - msg.pk as *const UpdateParticipantPublicKey, - pk as *const UpdateParticipantPublicKey, - ); - assert_eq!( - msg.sum_signature as *const Signature, - sum_signature as *const Signature, - ); - assert_eq!( - msg.update_signature as *const Signature, - update_signature as *const Signature, - ); - assert_eq!( - msg.certificate as *const Certificate, - certificate as *const Certificate, - ); - assert_eq!( - msg.masked_model as *const MaskedModel, - masked_model as *const MaskedModel - ); - assert_eq!( - msg.local_seed_dict as *const LocalSeedDict, - local_seed_dict as *const LocalSeedDict, - ); - - // serialize seed dictionary - let local_seed_vec = msg.serialize_local_seed_dict(); - assert_eq!( - local_seed_vec.len(), - (PK_BYTES + EncryptedMaskSeed::BYTES) * sum_dict_len - ); - assert!(local_seed_vec - .chunks_exact(PK_BYTES + EncryptedMaskSeed::BYTES) - .all(|chunk| { - local_seed_dict - .get(&SumParticipantPublicKey::from_slice_unchecked( - &chunk[..PK_BYTES], - )) - .unwrap() - .as_slice() - == &chunk[PK_BYTES..] - })); - - // serialize - let mut buffer = UpdateMessageBuffer::new(32, masked_model.len(), 112 * sum_dict_len); - let coord_pk = CoordinatorPublicKey::from_slice_unchecked(randombytes(32).as_slice()); - msg.serialize(&mut buffer, &coord_pk); - assert_eq!(buffer.tag(), [Tag::Update as u8].as_ref()); - assert_eq!(buffer.coord_pk(), coord_pk.as_slice()); - assert_eq!(buffer.part_pk(), pk.as_slice()); - assert_eq!(buffer.sum_signature(), sum_signature.as_slice()); - assert_eq!(buffer.update_signature(), update_signature.as_slice()); - assert_eq!( - buffer.certificate_len(), - certificate.len().to_le_bytes().as_ref(), - ); - assert_eq!(buffer.certificate(), certificate.as_slice()); - assert_eq!( - buffer.masked_model_len(), - masked_model.len().to_le_bytes().as_ref(), - ); - assert_eq!(buffer.masked_model(), masked_model.serialize().as_slice()); - assert_eq!( - buffer.local_seed_dict_len(), - (112 * sum_dict_len as usize).to_le_bytes().as_ref(), - ); - assert_eq!(buffer.local_seed_dict(), local_seed_vec.as_slice()); - } - - #[test] - fn test_updatemessage_deserialize() { - // deserialize seed dictionary - let sum_dict_len = 1 + randombytes_uniform(10) as usize; - let local_seed_vec = randombytes((PK_BYTES + EncryptedMaskSeed::BYTES) * sum_dict_len); - let local_seed_dict = UpdateMessage::deserialize_local_seed_dict(&local_seed_vec).unwrap(); - for chunk in local_seed_vec.chunks_exact(PK_BYTES + EncryptedMaskSeed::BYTES) { - assert_eq!( - local_seed_dict - .get(&SumParticipantPublicKey::from_slice_unchecked( - &chunk[..PK_BYTES] - )) - .unwrap(), - &EncryptedMaskSeed::from_slice_unchecked(&chunk[PK_BYTES..]), - ); - } - - // deserialize - let bytes = auxiliary_bytes(sum_dict_len); - let buffer = UpdateMessageBuffer::try_from(bytes.clone()).unwrap(); - let msg = UpdateMessage::deserialize(buffer.clone()).unwrap(); - assert_eq!( - msg.pk(), - &UpdateParticipantPublicKey::from_slice_unchecked(&bytes[MB::PART_PK_RANGE]), - ); - assert_eq!( - msg.sum_signature(), - &Signature::from_slice_unchecked(&bytes[MB::SUM_SIGNATURE_RANGE]), - ); - assert_eq!( - msg.update_signature(), - &Signature::from_slice_unchecked(&bytes[MB::UPDATE_SIGNATURE_RANGE]), - ); - assert_eq!( - msg.certificate(), - &Certificate::deserialize(&bytes[buffer.certificate_range.clone()]).unwrap() - ); - assert_eq!( - msg.masked_model(), - &MaskedModel::deserialize(&bytes[buffer.masked_model_range.clone()]).unwrap(), - ); - assert_eq!( - msg.local_seed_dict(), - &UpdateMessage::deserialize_local_seed_dict( - &bytes[buffer.local_seed_dict_range.clone()] - ) - .unwrap(), - ); - } - - #[test] - fn test_updatemessage() { - // seal - let sum_dict_len = 1 + randombytes_uniform(10) as usize; - let (pk, sk) = generate_signing_key_pair(); - let sum_signature = Signature::from_slice_unchecked(randombytes(64).as_slice()); - let update_signature = Signature::from_slice_unchecked(randombytes(64).as_slice()); - let certificate = Certificate::zeroed(); - let masked_model = auxiliary_masked_model(); - let local_seed_dict = iter::repeat_with(|| { - ( - SumParticipantPublicKey::from_slice_unchecked(randombytes(32).as_slice()), - EncryptedMaskSeed::from_slice_unchecked(randombytes(80).as_slice()), - ) - }) - .take(sum_dict_len) - .collect(); - let (coord_pk, coord_sk) = generate_encrypt_key_pair(); - let bytes = UpdateMessage::from_parts( - &pk, - &sum_signature, - &update_signature, - &certificate, - &masked_model, - &local_seed_dict, - ) - .seal(&sk, &coord_pk); - - // open - let msg = UpdateMessage::open(&bytes, &coord_pk, &coord_sk).unwrap(); - assert_eq!(msg.pk(), &pk); - assert_eq!(msg.sum_signature(), &sum_signature); - assert_eq!(msg.update_signature(), &update_signature); - assert_eq!(msg.certificate(), &certificate); - assert_eq!(msg.masked_model(), &masked_model); - assert_eq!(msg.local_seed_dict(), &local_seed_dict); - - // wrong signature - let bytes = auxiliary_bytes(sum_dict_len); - let mut buffer = UpdateMessageBuffer::try_from(bytes).unwrap(); - let msg = UpdateMessage::from_parts( - &pk, - &sum_signature, - &update_signature, - &certificate, - &masked_model, - &local_seed_dict, - ); - msg.serialize(&mut buffer, &coord_pk); - let bytes = coord_pk.encrypt(buffer.bytes()); - assert_eq!( - UpdateMessage::open(&bytes, &coord_pk, &coord_sk).unwrap_err(), - PetError::InvalidMessage, - ); - - // wrong receiver - msg.serialize( - &mut buffer, - &CoordinatorPublicKey::from_slice_unchecked(randombytes(32).as_slice()), - ); - let bytes = coord_pk.encrypt(buffer.bytes()); - assert_eq!( - UpdateMessage::open(&bytes, &coord_pk, &coord_sk).unwrap_err(), - PetError::InvalidMessage, - ); - - // wrong tag - buffer.tag_mut().copy_from_slice([Tag::None as u8].as_ref()); - let bytes = coord_pk.encrypt(buffer.bytes()); - assert_eq!( - UpdateMessage::open(&bytes, &coord_pk, &coord_sk).unwrap_err(), - PetError::InvalidMessage, - ); - - // wrong length - assert_eq!( - UpdateMessage::open([0_u8; 0].as_ref(), &coord_pk, &coord_sk).unwrap_err(), - PetError::InvalidMessage, - ); - } -} diff --git a/rust/src/message/utils.rs b/rust/src/message/utils.rs new file mode 100644 index 000000000..3531aaac2 --- /dev/null +++ b/rust/src/message/utils.rs @@ -0,0 +1,5 @@ +use std::ops::Range; + +pub(crate) const fn range(start: usize, length: usize) -> Range { + start..(start + length) +} diff --git a/rust/src/participant.rs b/rust/src/participant.rs index 20fe260f3..99698a596 100644 --- a/rust/src/participant.rs +++ b/rust/src/participant.rs @@ -103,13 +103,14 @@ impl Participant { /// Compose a sum message. pub fn compose_sum_message(&mut self, pk: &CoordinatorPublicKey) -> Vec { self.gen_ephm_keypair(); - SumMessage::from_parts( - &self.pk, - &self.sum_signature, - &self.ephm_pk, - &self.certificate, - ) - .seal(&self.sk, pk) + + let payload = SumOwned { + sum_signature: self.sum_signature, + ephm_pk: self.ephm_pk, + }; + + let message = MessageOwned::new_sum(*pk, self.pk, payload); + self.seal_message(pk, &message) } /// Compose an update message. @@ -126,21 +127,22 @@ impl Participant { // safe unwrap: data types of model and mask configuration conform due to definition above let (mask_seed, masked_model) = model.mask(scalar, &mask_config); let local_seed_dict = Self::create_local_seed_dict(sum_dict, &mask_seed); - UpdateMessage::from_parts( - &self.pk, - &self.sum_signature, - &self.update_signature, - &self.certificate, - &masked_model, - &local_seed_dict, - ) - .seal(&self.sk, pk) + + let payload = UpdateOwned { + sum_signature: self.sum_signature, + update_signature: self.update_signature, + masked_model: masked_model, + local_seed_dict: local_seed_dict, + }; + + let message = MessageOwned::new_update(pk, self.pk, payload); + self.seal_message(&pk, &message) } /// Compose a sum2 message. pub fn compose_sum2_message( &self, - pk: &CoordinatorPublicKey, + pk: CoordinatorPublicKey, seed_dict: &SeedDict, ) -> Result, PetError> { let mask_seeds = self.get_seeds(seed_dict)?; diff --git a/rust/src/service/data.rs b/rust/src/service/data.rs index a1c73bbaa..31e615374 100644 --- a/rust/src/service/data.rs +++ b/rust/src/service/data.rs @@ -1,4 +1,3 @@ -use serde::{Deserialize, Serialize}; use thiserror::Error; use crate::{ @@ -7,11 +6,9 @@ use crate::{ service::handle::{SerializedSeedDict, SerializedSumDict}, MaskHash, SeedDict, - SumDict, SumParticipantPublicKey, }; use derive_more::From; -use sodiumoxide::crypto::box_; use std::{collections::HashMap, sync::Arc}; /// Data that the service keeps track of. diff --git a/rust/src/service/handle.rs b/rust/src/service/handle.rs index debeb1a92..741385a06 100644 --- a/rust/src/service/handle.rs +++ b/rust/src/service/handle.rs @@ -1,7 +1,6 @@ use crate::{coordinator::RoundParameters, SumParticipantPublicKey}; use derive_more::From; use std::{ - collections::HashMap, pin::Pin, sync::Arc, task::{Context, Poll}, diff --git a/rust/src/service/mod.rs b/rust/src/service/mod.rs index a6f9e8e5e..c475c8d2f 100644 --- a/rust/src/service/mod.rs +++ b/rust/src/service/mod.rs @@ -4,22 +4,12 @@ use crate::{ }; use derive_more::From; use futures::ready; -use sodiumoxide::crypto::box_; use std::{ - collections::HashMap, - error::Error, future::Future, pin::Pin, - sync::Arc, task::{Context, Poll}, }; -use tokio::{ - stream::Stream, - sync::{ - mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, - oneshot, - }, -}; +use tokio::stream::Stream; mod data; mod handle;