diff --git a/changelog.d/5-internal/mls-messages b/changelog.d/5-internal/mls-messages new file mode 100644 index 0000000000..45f0812ab0 --- /dev/null +++ b/changelog.d/5-internal/mls-messages @@ -0,0 +1 @@ +Add MLS message types and corresponding deserialisers diff --git a/libs/wire-api/package.yaml b/libs/wire-api/package.yaml index 02215683ab..e809cb0758 100644 --- a/libs/wire-api/package.yaml +++ b/libs/wire-api/package.yaml @@ -114,6 +114,7 @@ tests: - cassava - currency-codes - directory + - either - hex - iso3166-country-codes - iso639 diff --git a/libs/wire-api/src/Wire/API/MLS/Commit.hs b/libs/wire-api/src/Wire/API/MLS/Commit.hs new file mode 100644 index 0000000000..22c14dda9d --- /dev/null +++ b/libs/wire-api/src/Wire/API/MLS/Commit.hs @@ -0,0 +1,55 @@ +-- This file is part of the Wire Server implementation. +-- +-- Copyright (C) 2022 Wire Swiss GmbH +-- +-- This program is free software: you can redistribute it and/or modify it under +-- the terms of the GNU Affero General Public License as published by the Free +-- Software Foundation, either version 3 of the License, or (at your option) any +-- later version. +-- +-- This program is distributed in the hope that it will be useful, but WITHOUT +-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more +-- details. +-- +-- You should have received a copy of the GNU Affero General Public License along +-- with this program. If not, see . + +module Wire.API.MLS.Commit where + +import Imports +import Wire.API.MLS.KeyPackage +import Wire.API.MLS.Proposal +import Wire.API.MLS.Serialisation + +data Commit = Commit + { cProposals :: [ProposalOrRef], + cPath :: Maybe UpdatePath + } + +instance ParseMLS Commit where + parseMLS = Commit <$> parseMLSVector @Word32 parseMLS <*> parseMLSOptional parseMLS + +data UpdatePath = UpdatePath + { upLeaf :: KeyPackage, + upNodes :: [UpdatePathNode] + } + +instance ParseMLS UpdatePath where + parseMLS = UpdatePath <$> parseMLS <*> parseMLSVector @Word32 parseMLS + +data UpdatePathNode = UpdatePathNode + { upnPublicKey :: ByteString, + upnSecret :: [HPKECiphertext] + } + +instance ParseMLS UpdatePathNode where + parseMLS = UpdatePathNode <$> parseMLSBytes @Word16 <*> parseMLSVector @Word32 parseMLS + +data HPKECiphertext = HPKECiphertext + { hcOutput :: ByteString, + hcCiphertext :: ByteString + } + +instance ParseMLS HPKECiphertext where + parseMLS = HPKECiphertext <$> parseMLSBytes @Word16 <*> parseMLSBytes @Word16 diff --git a/libs/wire-api/src/Wire/API/MLS/Credential.hs b/libs/wire-api/src/Wire/API/MLS/Credential.hs index 2922ca76e4..2db0616336 100644 --- a/libs/wire-api/src/Wire/API/MLS/Credential.hs +++ b/libs/wire-api/src/Wire/API/MLS/Credential.hs @@ -43,21 +43,20 @@ data Credential = BasicCredential deriving stock (Eq, Show, Generic) deriving (Arbitrary) via GenericUniform Credential -data CredentialTag = ReservedCredentialTag | BasicCredentialTag - deriving stock (Enum, Bounded, Show) - deriving (ParseMLS) via (EnumMLS Word16 CredentialTag) +data CredentialTag = BasicCredentialTag + deriving stock (Enum, Bounded, Eq, Show) + +instance ParseMLS CredentialTag where + parseMLS = parseMLSEnum @Word16 "credential type" instance ParseMLS Credential where - parseMLS = do - tag <- parseMLS - case tag of + parseMLS = + parseMLS >>= \case BasicCredentialTag -> BasicCredential <$> parseMLSBytes @Word16 <*> parseMLS <*> parseMLSBytes @Word16 - ReservedCredentialTag -> - fail "Unexpected credential type" credentialTag :: Credential -> CredentialTag credentialTag (BasicCredential _ _ _) = BasicCredentialTag diff --git a/libs/wire-api/src/Wire/API/MLS/Group.hs b/libs/wire-api/src/Wire/API/MLS/Group.hs new file mode 100644 index 0000000000..f7a5d9d824 --- /dev/null +++ b/libs/wire-api/src/Wire/API/MLS/Group.hs @@ -0,0 +1,30 @@ +-- This file is part of the Wire Server implementation. +-- +-- Copyright (C) 2022 Wire Swiss GmbH +-- +-- This program is free software: you can redistribute it and/or modify it under +-- the terms of the GNU Affero General Public License as published by the Free +-- Software Foundation, either version 3 of the License, or (at your option) any +-- later version. +-- +-- This program is distributed in the hope that it will be useful, but WITHOUT +-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more +-- details. +-- +-- You should have received a copy of the GNU Affero General Public License along +-- with this program. If not, see . + +module Wire.API.MLS.Group where + +import Imports +import Wire.API.MLS.Serialisation + +newtype GroupId = GroupId {unGroupId :: ByteString} + deriving (Eq, Show) + +instance IsString GroupId where + fromString = GroupId . fromString + +instance ParseMLS GroupId where + parseMLS = GroupId <$> parseMLSBytes @Word8 diff --git a/libs/wire-api/src/Wire/API/MLS/KeyPackage.hs b/libs/wire-api/src/Wire/API/MLS/KeyPackage.hs index 6c1958f281..74f819fc64 100644 --- a/libs/wire-api/src/Wire/API/MLS/KeyPackage.hs +++ b/libs/wire-api/src/Wire/API/MLS/KeyPackage.hs @@ -40,7 +40,6 @@ module Wire.API.MLS.KeyPackage decodeExtension, parseExtension, ExtensionTag (..), - ReservedExtensionTagSym0, CapabilitiesExtensionTagSym0, LifetimeExtensionTagSym0, SExtensionTag (..), @@ -56,7 +55,6 @@ module Wire.API.MLS.KeyPackage where import Control.Applicative -import Control.Error.Util import Control.Lens hiding (set, (.=)) import Data.Aeson (FromJSON, ToJSON) import Data.Binary @@ -156,20 +154,17 @@ instance ParseMLS Extension where parseMLS = Extension <$> parseMLS <*> parseMLSBytes @Word32 data ExtensionTag - = ReservedExtensionTag - | CapabilitiesExtensionTag + = CapabilitiesExtensionTag | LifetimeExtensionTag deriving (Bounded, Enum) $(genSingletons [''ExtensionTag]) type family ExtensionType (t :: ExtensionTag) :: * where - ExtensionType 'ReservedExtensionTag = () ExtensionType 'CapabilitiesExtensionTag = Capabilities ExtensionType 'LifetimeExtensionTag = Lifetime parseExtension :: Sing t -> Get (ExtensionType t) -parseExtension SReservedExtensionTag = pure () parseExtension SCapabilitiesExtensionTag = parseMLS parseExtension SLifetimeExtensionTag = parseMLS @@ -182,16 +177,16 @@ instance Eq SomeExtension where _ == _ = False instance Show SomeExtension where - show (SomeExtension SReservedExtensionTag _) = show () show (SomeExtension SCapabilitiesExtensionTag caps) = show caps show (SomeExtension SLifetimeExtensionTag lt) = show lt -decodeExtension :: Extension -> Maybe SomeExtension +decodeExtension :: Extension -> Either Text (Maybe SomeExtension) decodeExtension e = do - t <- safeToEnum (fromIntegral (extType e)) - hush $ - withSomeSing t $ \st -> - decodeMLSWith' (SomeExtension st <$> parseExtension st) (extData e) + case toMLSEnum' (extType e) of + Left MLSEnumUnkonwn -> pure Nothing + Left MLSEnumInvalid -> Left "Invalid extension type" + Right t -> withSomeSing t $ \st -> + Just <$> decodeMLSWith' (SomeExtension st <$> parseExtension st) (extData e) data Capabilities = Capabilities { capVersions :: [ProtocolVersion], @@ -234,7 +229,7 @@ data KeyPackageTBS = KeyPackageTBS kpCredential :: Credential, kpExtensions :: [Extension] } - deriving stock (Show, Generic) + deriving stock (Eq, Show, Generic) deriving (Arbitrary) via GenericUniform KeyPackageTBS instance ParseMLS KeyPackageTBS where @@ -250,10 +245,13 @@ data KeyPackage = KeyPackage { kpTBS :: KeyPackageTBS, kpSignature :: ByteString } - deriving (Show) + deriving stock (Eq, Show) newtype KeyPackageRef = KeyPackageRef {unKeyPackageRef :: ByteString} - deriving stock (Show) + deriving stock (Eq, Show) + +instance ParseMLS KeyPackageRef where + parseMLS = KeyPackageRef <$> getByteString 16 kpRef :: CipherSuiteTag -> KeyPackageData -> KeyPackageRef kpRef cs = diff --git a/libs/wire-api/src/Wire/API/MLS/Message.hs b/libs/wire-api/src/Wire/API/MLS/Message.hs new file mode 100644 index 0000000000..03d37e407f --- /dev/null +++ b/libs/wire-api/src/Wire/API/MLS/Message.hs @@ -0,0 +1,154 @@ +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} + +-- This file is part of the Wire Server implementation. +-- +-- Copyright (C) 2022 Wire Swiss GmbH +-- +-- This program is free software: you can redistribute it and/or modify it under +-- the terms of the GNU Affero General Public License as published by the Free +-- Software Foundation, either version 3 of the License, or (at your option) any +-- later version. +-- +-- This program is distributed in the hope that it will be useful, but WITHOUT +-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more +-- details. +-- +-- You should have received a copy of the GNU Affero General Public License along +-- with this program. If not, see . + +module Wire.API.MLS.Message + ( Message (..), + WireFormatTag (..), + SWireFormatTag (..), + SomeMessage (..), + ContentType (..), + MessagePayload (..), + MessagePayloadTBS (..), + Sender (..), + MLSPlainTextSym0, + MLSCipherTextSym0, + ) +where + +import Data.Binary +import Data.Singletons.TH +import Imports +import Wire.API.MLS.Commit +import Wire.API.MLS.Group +import Wire.API.MLS.KeyPackage +import Wire.API.MLS.Proposal +import Wire.API.MLS.Serialisation + +data WireFormatTag = MLSPlainText | MLSCipherText + deriving (Bounded, Enum, Eq, Show) + +$(genSingletons [''WireFormatTag]) + +instance ParseMLS WireFormatTag where + parseMLS = parseMLSEnum @Word8 "wire format" + +data Message (tag :: WireFormatTag) = Message + { msgGroupId :: GroupId, + msgEpoch :: Word64, + msgAuthData :: ByteString, + msgSender :: Sender tag, + msgPayload :: MessagePayload tag + } + +instance ParseMLS (Message 'MLSPlainText) where + parseMLS = do + g <- parseMLS + e <- parseMLS + s <- parseMLS + d <- parseMLSBytes @Word32 + p <- parseMLS + pure (Message g e d s p) + +instance ParseMLS (Message 'MLSCipherText) where + parseMLS = do + g <- parseMLS + e <- parseMLS + ct <- parseMLS + d <- parseMLSBytes @Word32 + s <- parseMLS + p <- parseMLSBytes @Word32 + pure $ Message g e d s (CipherText ct p) + +data SomeMessage where + SomeMessage :: Sing tag -> Message tag -> SomeMessage + +instance ParseMLS SomeMessage where + parseMLS = + parseMLS >>= \case + MLSPlainText -> SomeMessage SMLSPlainText <$> parseMLS + MLSCipherText -> SomeMessage SMLSCipherText <$> parseMLS + +data family Sender (tag :: WireFormatTag) :: * + +data instance Sender 'MLSCipherText = EncryptedSender {esData :: ByteString} + +instance ParseMLS (Sender 'MLSCipherText) where + parseMLS = EncryptedSender <$> parseMLSBytes @Word8 + +data SenderTag = MemberSenderTag | PreconfiguredSenderTag | NewMemberSenderTag + deriving (Bounded, Enum, Show, Eq) + +instance ParseMLS SenderTag where + parseMLS = parseMLSEnum @Word8 "sender type" + +data instance Sender 'MLSPlainText + = MemberSender KeyPackageRef + | PreconfiguredSender ByteString + | NewMemberSender + +instance ParseMLS (Sender 'MLSPlainText) where + parseMLS = + parseMLS >>= \case + MemberSenderTag -> MemberSender <$> parseMLS + PreconfiguredSenderTag -> PreconfiguredSender <$> parseMLSBytes @Word8 + NewMemberSenderTag -> pure NewMemberSender + +data family MessagePayload (tag :: WireFormatTag) :: * + +data instance MessagePayload 'MLSCipherText = CipherText + { msgContentType :: Word8, + msgCipherText :: ByteString + } + +data instance MessagePayload 'MLSPlainText = MessagePayload + { msgTBS :: MessagePayloadTBS, + msgSignature :: ByteString, + msgConfirmation :: Maybe ByteString, + msgMembership :: Maybe ByteString + } + +instance ParseMLS (MessagePayload 'MLSPlainText) where + parseMLS = + MessagePayload + <$> parseMLS + <*> parseMLSBytes @Word16 + <*> parseMLSOptional (parseMLSBytes @Word8) + <*> parseMLSOptional (parseMLSBytes @Word8) + +data MessagePayloadTBS + = ApplicationMessage ByteString + | ProposalMessage Proposal + | CommitMessage Commit + +data ContentType + = ApplicationMessageTag + | ProposalMessageTag + | CommitMessageTag + deriving (Bounded, Enum, Eq, Show) + +instance ParseMLS ContentType where + parseMLS = parseMLSEnum @Word8 "content type" + +instance ParseMLS MessagePayloadTBS where + parseMLS = + parseMLS >>= \case + ApplicationMessageTag -> ApplicationMessage <$> parseMLSBytes @Word32 + ProposalMessageTag -> ProposalMessage <$> parseMLS + CommitMessageTag -> CommitMessage <$> parseMLS diff --git a/libs/wire-api/src/Wire/API/MLS/Proposal.hs b/libs/wire-api/src/Wire/API/MLS/Proposal.hs index abfc1553eb..801ff69bbf 100644 --- a/libs/wire-api/src/Wire/API/MLS/Proposal.hs +++ b/libs/wire-api/src/Wire/API/MLS/Proposal.hs @@ -18,20 +18,130 @@ module Wire.API.MLS.Proposal where import Data.Binary +import Data.Binary.Get import Imports import Wire.API.Arbitrary +import Wire.API.MLS.CipherSuite +import Wire.API.MLS.Group +import Wire.API.MLS.KeyPackage import Wire.API.MLS.Serialisation -data ProposalType - = AddProposal - | UpdateProposal - | RemoveProposal - | PreSharedKeyProposal - | ReInitProposal - | ExternalInitProposal - | AppAckProposal - | GroupContextExtensionsProposal - | ExternalProposal +data ProposalTag + = AddProposalTag + | UpdateProposalTag + | RemoveProposalTag + | PreSharedKeyProposalTag + | ReInitProposalTag + | ExternalInitProposalTag + | AppAckProposalTag + | GroupContextExtensionsProposalTag deriving stock (Bounded, Enum, Eq, Generic, Show) - deriving (ParseMLS) via (EnumMLS Word16 ProposalType) - deriving (Arbitrary) via GenericUniform ProposalType + deriving (Arbitrary) via GenericUniform ProposalTag + +instance ParseMLS ProposalTag where + parseMLS = parseMLSEnum @Word16 "proposal type" + +data Proposal + = AddProposal KeyPackage + | UpdateProposal KeyPackage + | RemoveProposal KeyPackageRef + | PreSharedKeyProposal PreSharedKeyID + | ReInitProposal ReInit + | ExternalInitProposal ByteString + | AppAckProposal [MessageRange] + | GroupContextExtensionsProposal [Extension] + deriving stock (Eq, Show) + +instance ParseMLS Proposal where + parseMLS = + parseMLS >>= \case + AddProposalTag -> AddProposal <$> parseMLS + UpdateProposalTag -> UpdateProposal <$> parseMLS + RemoveProposalTag -> RemoveProposal <$> parseMLS + PreSharedKeyProposalTag -> PreSharedKeyProposal <$> parseMLS + ReInitProposalTag -> ReInitProposal <$> parseMLS + ExternalInitProposalTag -> ExternalInitProposal <$> parseMLSBytes @Word16 + AppAckProposalTag -> AppAckProposal <$> parseMLSVector @Word32 parseMLS + GroupContextExtensionsProposalTag -> + GroupContextExtensionsProposal <$> parseMLSVector @Word32 parseMLS + +data PreSharedKeyTag = ExternalKeyTag | ResumptionKeyTag + deriving (Bounded, Enum, Eq, Show) + +instance ParseMLS PreSharedKeyTag where + parseMLS = parseMLSEnum @Word16 "PreSharedKeyID type" + +data PreSharedKeyID = ExternalKeyID ByteString | ResumptionKeyID Resumption + deriving stock (Eq, Show) + +instance ParseMLS PreSharedKeyID where + parseMLS = do + t <- parseMLS + case t of + ExternalKeyTag -> ExternalKeyID <$> parseMLSBytes @Word8 + ResumptionKeyTag -> ResumptionKeyID <$> parseMLS + +data Resumption = Resumption + { resUsage :: Word8, + resGroupId :: GroupId, + resEpoch :: Word64 + } + deriving stock (Eq, Show) + +instance ParseMLS Resumption where + parseMLS = + Resumption + <$> parseMLS + <*> parseMLS + <*> parseMLS + +data ReInit = ReInit + { riGroupId :: GroupId, + riProtocolVersion :: ProtocolVersion, + riCipherSuite :: CipherSuite, + riExtensions :: [Extension] + } + deriving stock (Eq, Show) + +instance ParseMLS ReInit where + parseMLS = + ReInit + <$> parseMLS + <*> parseMLS + <*> parseMLS + <*> parseMLSVector @Word32 parseMLS + +data MessageRange = MessageRange + { mrSender :: KeyPackageRef, + mrFirstGeneration :: Word32, + mrLastGenereation :: Word32 + } + deriving stock (Eq, Show) + +instance ParseMLS MessageRange where + parseMLS = + MessageRange + <$> parseMLS + <*> parseMLS + <*> parseMLS + +data ProposalOrRefTag = InlineTag | RefTag + deriving stock (Bounded, Enum, Eq, Show) + +instance ParseMLS ProposalOrRefTag where + parseMLS = parseMLSEnum @Word8 "ProposalOrRef type" + +data ProposalOrRef = Inline Proposal | Ref ProposalRef + deriving stock (Eq, Show) + +instance ParseMLS ProposalOrRef where + parseMLS = + parseMLS >>= \case + InlineTag -> Inline <$> parseMLS + RefTag -> Ref <$> parseMLS + +newtype ProposalRef = ProposalRef {unProposalRef :: ByteString} + deriving stock (Eq, Show) + +instance ParseMLS ProposalRef where + parseMLS = ProposalRef <$> getByteString 16 diff --git a/libs/wire-api/src/Wire/API/MLS/Serialisation.hs b/libs/wire-api/src/Wire/API/MLS/Serialisation.hs index f773e8fa2e..cc6db424ef 100644 --- a/libs/wire-api/src/Wire/API/MLS/Serialisation.hs +++ b/libs/wire-api/src/Wire/API/MLS/Serialisation.hs @@ -19,9 +19,13 @@ module Wire.API.MLS.Serialisation ( ParseMLS (..), parseMLSVector, parseMLSBytes, + parseMLSOptional, + parseMLSEnum, BinaryMLS (..), - EnumMLS (..), - safeToEnum, + MLSEnumError (..), + fromMLSEnum, + toMLSEnum', + toMLSEnum, decodeMLS, decodeMLS', decodeMLSWith, @@ -58,6 +62,40 @@ parseMLSBytes = do len <- fromIntegral <$> get @w getByteString len +parseMLSOptional :: Get a -> Get (Maybe a) +parseMLSOptional g = do + b <- getWord8 + sequenceA $ guard (b /= 0) $> g + +-- | Parse a positive tag for an enumeration. The value 0 is considered +-- "reserved", and all other values are shifted down by 1 to get the +-- corresponding enumeration index. This makes it possible to parse enumeration +-- types that don't contain an explicit constructor for a "reserved" value. +parseMLSEnum :: + forall (w :: *) a. + (Bounded a, Enum a, Integral w, Binary w) => + String -> + Get a +parseMLSEnum name = toMLSEnum name =<< get @w + +data MLSEnumError = MLSEnumUnkonwn | MLSEnumInvalid + +toMLSEnum' :: forall a w. (Bounded a, Enum a, Integral w) => w -> Either MLSEnumError a +toMLSEnum' w = case fromIntegral w - 1 of + n + | n < 0 -> Left MLSEnumInvalid + | n < fromEnum @a minBound || n > fromEnum @a maxBound -> Left MLSEnumUnkonwn + | otherwise -> pure (toEnum n) + +toMLSEnum :: forall a w f. (Bounded a, Enum a, MonadFail f, Integral w) => String -> w -> f a +toMLSEnum name = either err pure . toMLSEnum' + where + err MLSEnumUnkonwn = fail $ "Unknown " <> name + err MLSEnumInvalid = fail $ "Invalid " <> name + +fromMLSEnum :: (Integral w, Enum a) => a -> w +fromMLSEnum = fromIntegral . succ . fromEnum + instance ParseMLS Word8 where parseMLS = get instance ParseMLS Word16 where parseMLS = get @@ -72,21 +110,6 @@ newtype BinaryMLS a = BinaryMLS a instance Binary a => ParseMLS (BinaryMLS a) where parseMLS = BinaryMLS <$> get --- | A wrapper to generate a 'Binary' instance for an enumerated type. -newtype EnumMLS w a = EnumMLS {unEnumMLS :: a} - -safeToEnum :: forall a f. (Bounded a, Enum a, MonadFail f) => Int -> f a -safeToEnum n - | n >= fromEnum @a minBound && n <= fromEnum @a maxBound = - pure (toEnum n) - | otherwise = - fail "Out of bound enumeration" - -instance (Binary w, Integral w, Bounded a, Enum a) => ParseMLS (EnumMLS w a) where - parseMLS = do - n <- fromIntegral <$> get @w - EnumMLS <$> safeToEnum n - -- | Decode an MLS value from a lazy bytestring. Return an error message in case of failure. decodeMLS :: ParseMLS a => LByteString -> Either Text a decodeMLS = decodeMLSWith parseMLS diff --git a/libs/wire-api/src/Wire/API/MLS/Welcome.hs b/libs/wire-api/src/Wire/API/MLS/Welcome.hs new file mode 100644 index 0000000000..76166969f4 --- /dev/null +++ b/libs/wire-api/src/Wire/API/MLS/Welcome.hs @@ -0,0 +1,47 @@ +-- This file is part of the Wire Server implementation. +-- +-- Copyright (C) 2022 Wire Swiss GmbH +-- +-- This program is free software: you can redistribute it and/or modify it under +-- the terms of the GNU Affero General Public License as published by the Free +-- Software Foundation, either version 3 of the License, or (at your option) any +-- later version. +-- +-- This program is distributed in the hope that it will be useful, but WITHOUT +-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more +-- details. +-- +-- You should have received a copy of the GNU Affero General Public License along +-- with this program. If not, see . + +module Wire.API.MLS.Welcome where + +import Imports +import Wire.API.MLS.CipherSuite +import Wire.API.MLS.Commit +import Wire.API.MLS.KeyPackage +import Wire.API.MLS.Serialisation + +data Welcome = Welcome + { welCipherSuite :: CipherSuite, + welSecrets :: [GroupSecrets], + welGroupInfo :: ByteString + } + +instance ParseMLS Welcome where + parseMLS = + Welcome + -- Note: the extra protocol version at the beginning of the welcome + -- message is present in openmls-0.4.0-pre, but is not part of the spec + <$> (parseMLS @ProtocolVersion *> parseMLS) + <*> parseMLSVector @Word32 parseMLS + <*> parseMLSBytes @Word32 + +data GroupSecrets = GroupSecrets + { gsNewMember :: KeyPackageRef, + gsSecrets :: HPKECiphertext + } + +instance ParseMLS GroupSecrets where + parseMLS = GroupSecrets <$> parseMLS <*> parseMLS diff --git a/libs/wire-api/test/resources/app_message1.mls b/libs/wire-api/test/resources/app_message1.mls new file mode 100644 index 0000000000..0426a5a98f Binary files /dev/null and b/libs/wire-api/test/resources/app_message1.mls differ diff --git a/libs/wire-api/test/resources/commit1.mls b/libs/wire-api/test/resources/commit1.mls new file mode 100644 index 0000000000..c8f40b1bf1 Binary files /dev/null and b/libs/wire-api/test/resources/commit1.mls differ diff --git a/libs/wire-api/test/resources/welcome1.mls b/libs/wire-api/test/resources/welcome1.mls new file mode 100644 index 0000000000..0a628b1097 Binary files /dev/null and b/libs/wire-api/test/resources/welcome1.mls differ diff --git a/libs/wire-api/test/unit/Test/Wire/API/MLS.hs b/libs/wire-api/test/unit/Test/Wire/API/MLS.hs index 662b94662a..9302b91079 100644 --- a/libs/wire-api/test/unit/Test/Wire/API/MLS.hs +++ b/libs/wire-api/test/unit/Test/Wire/API/MLS.hs @@ -20,6 +20,8 @@ module Test.Wire.API.MLS where import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy as LBS import Data.Domain +import Data.Either.Combinators +import Data.Hex import Data.Id import qualified Data.Text as T import qualified Data.UUID as UUID @@ -27,14 +29,21 @@ import Imports import Test.Tasty import Test.Tasty.HUnit import Wire.API.MLS.CipherSuite +import Wire.API.MLS.Commit import Wire.API.MLS.Credential import Wire.API.MLS.KeyPackage +import Wire.API.MLS.Message +import Wire.API.MLS.Proposal import Wire.API.MLS.Serialisation +import Wire.API.MLS.Welcome tests :: TestTree tests = testGroup "MLS" $ - [ testCase "parse key packages" testParseKeyPackage + [ testCase "parse key package" testParseKeyPackage, + testCase "parse commit message" testParseCommit, + testCase "parse application message" testParseApplication, + testCase "parse welcome message" testParseWelcome ] testParseKeyPackage :: IO () @@ -55,3 +64,52 @@ testParseKeyPackage = do ciUser = Id (fromJust (UUID.fromString "b455a431-9db6-4404-86e7-6a3ebe73fcaf")), ciClient = newClientId 0x3ae58155 } + +testParseCommit :: IO () +testParseCommit = do + msgData <- LBS.readFile "test/resources/commit1.mls" + msg :: Message 'MLSPlainText <- case decodeMLS @SomeMessage msgData of + Left err -> assertFailure (T.unpack err) + Right (SomeMessage SMLSCipherText _) -> + assertFailure "Expected plain text message, found encrypted" + Right (SomeMessage SMLSPlainText msg) -> + pure msg + + msgGroupId msg @?= "test_group" + msgEpoch msg @?= 0 + + case msgSender msg of + MemberSender kp -> kp @?= KeyPackageRef (fromRight' (unhex "24e4b0a802a2b81f00a9af7df5e91da8")) + _ -> assertFailure "Unexpected sender type" + + let payload = msgPayload msg + commit <- case msgTBS payload of + CommitMessage c -> pure c + _ -> assertFailure "Unexpected message type" + + case cProposals commit of + [Inline (AddProposal _)] -> pure () + _ -> assertFailure "Unexpected proposals" + +testParseApplication :: IO () +testParseApplication = do + msgData <- LBS.readFile "test/resources/app_message1.mls" + msg :: Message 'MLSCipherText <- case decodeMLS @SomeMessage msgData of + Left err -> assertFailure (T.unpack err) + Right (SomeMessage SMLSCipherText msg) -> pure msg + Right (SomeMessage SMLSPlainText _) -> + assertFailure "Expected encrypted message, found plain text" + + msgGroupId msg @?= "test_group" + msgEpoch msg @?= 0 + msgContentType (msgPayload msg) @?= fromMLSEnum ApplicationMessageTag + +testParseWelcome :: IO () +testParseWelcome = do + welData <- LBS.readFile "test/resources/welcome1.mls" + wel <- case decodeMLS welData of + Left err -> assertFailure (T.unpack err) + Right x -> pure x + + welCipherSuite wel @?= CipherSuite 1 + map gsNewMember (welSecrets wel) @?= [KeyPackageRef (fromRight' (unhex "ab4692703ca6d50ffdeaae3096f885c2"))] diff --git a/libs/wire-api/wire-api.cabal b/libs/wire-api/wire-api.cabal index 8fd943c89c..bee038d36e 100644 --- a/libs/wire-api/wire-api.cabal +++ b/libs/wire-api/wire-api.cabal @@ -37,11 +37,15 @@ library Wire.API.Message Wire.API.Message.Proto Wire.API.MLS.CipherSuite + Wire.API.MLS.Commit Wire.API.MLS.Credential + Wire.API.MLS.Group Wire.API.MLS.KeyPackage + Wire.API.MLS.Message Wire.API.MLS.Proposal Wire.API.MLS.Serialisation Wire.API.MLS.Servant + Wire.API.MLS.Welcome Wire.API.Notification Wire.API.Properties Wire.API.Provider @@ -643,6 +647,7 @@ test-suite wire-api-tests , containers >=0.5 , currency-codes , directory + , either , filepath , hex , hscim diff --git a/services/brig/src/Brig/API/MLS/KeyPackages/Validation.hs b/services/brig/src/Brig/API/MLS/KeyPackages/Validation.hs index cb9b600ee9..83a9f1e94e 100644 --- a/services/brig/src/Brig/API/MLS/KeyPackages/Validation.hs +++ b/services/brig/src/Brig/API/MLS/KeyPackages/Validation.hs @@ -107,10 +107,9 @@ findExtensions :: [Extension] -> Either Text (RequiredExtensions Identity) findExtensions = (checkRequiredExtensions =<<) . getAp . foldMap findExtension findExtension :: Extension -> Ap (Either Text) (RequiredExtensions Maybe) -findExtension ext = flip foldMap (decodeExtension ext) $ \case +findExtension ext = (Ap (decodeExtension ext) >>=) . foldMap $ \case (SomeExtension SLifetimeExtensionTag lt) -> pure $ RequiredExtensions (Just lt) Nothing (SomeExtension SCapabilitiesExtensionTag _) -> pure $ RequiredExtensions Nothing (Just ()) - _ -> Ap (Left "Invalid extension") validateExtensions :: [Extension] -> Handler r () validateExtensions exts = do diff --git a/services/brig/test/unit/Test/Brig/MLS.hs b/services/brig/test/unit/Test/Brig/MLS.hs index f38e8e4852..de7dc37bd8 100644 --- a/services/brig/test/unit/Test/Brig/MLS.hs +++ b/services/brig/test/unit/Test/Brig/MLS.hs @@ -28,6 +28,7 @@ import Test.Tasty import Test.Tasty.QuickCheck import Wire.API.MLS.CipherSuite import Wire.API.MLS.KeyPackage +import Wire.API.MLS.Serialisation -- | A lifetime with a length of at least 1 day. newtype ValidLifetime = ValidLifetime Lifetime @@ -61,14 +62,20 @@ newtype ValidExtensions = ValidExtensions [Extension] instance Show ValidExtensions where show (ValidExtensions exts) = "ValidExtensions (length " <> show (length exts) <> ")" +unknownExt :: Gen Extension +unknownExt = do + Positive t0 <- arbitrary + let t = t0 + fromEnum (maxBound :: ExtensionTag) + 1 + Extension (fromIntegral t) <$> arbitrary + -- | Generate a list of extensions containing all the required ones. instance Arbitrary ValidExtensions where arbitrary = do - exts0 <- listOf (arbitrary `suchThat` ((/= 0) . extType)) + exts0 <- listOf unknownExt LifetimeAndExtension ext1 _ <- arbitrary - exts2 <- listOf (arbitrary `suchThat` ((/= 0) . extType)) + exts2 <- listOf unknownExt CapabilitiesAndExtension ext3 _ <- arbitrary - exts4 <- listOf (arbitrary `suchThat` ((/= 0) . extType)) + exts4 <- listOf unknownExt pure . ValidExtensions $ exts0 <> [ext1] <> exts2 <> [ext3] <> exts4 newtype InvalidExtensions = InvalidExtensions [Extension] @@ -79,7 +86,7 @@ instance Show InvalidExtensions where instance Arbitrary InvalidExtensions where arbitrary = do - req <- fromIntegral . fromEnum <$> elements [LifetimeExtensionTag, CapabilitiesExtensionTag] + req <- fromMLSEnum <$> elements [LifetimeExtensionTag, CapabilitiesExtensionTag] InvalidExtensions <$> listOf (arbitrary `suchThat` ((/= req) . extType)) data LifetimeAndExtension = LifetimeAndExtension Extension Lifetime @@ -88,7 +95,7 @@ data LifetimeAndExtension = LifetimeAndExtension Extension Lifetime instance Arbitrary LifetimeAndExtension where arbitrary = do lt <- arbitrary - let ext = Extension (fromIntegral (fromEnum LifetimeExtensionTag)) . LBS.toStrict . runPut $ do + let ext = Extension (fromIntegral (fromEnum LifetimeExtensionTag + 1)) . LBS.toStrict . runPut $ do put (timestampSeconds (ltNotBefore lt)) put (timestampSeconds (ltNotAfter lt)) pure $ LifetimeAndExtension ext lt @@ -99,7 +106,7 @@ data CapabilitiesAndExtension = CapabilitiesAndExtension Extension Capabilities instance Arbitrary CapabilitiesAndExtension where arbitrary = do caps <- arbitrary - let ext = Extension (fromIntegral (fromEnum CapabilitiesExtensionTag)) . LBS.toStrict . runPut $ do + let ext = Extension (fromIntegral (fromEnum CapabilitiesExtensionTag + 1)) . LBS.toStrict . runPut $ do putWord8 (fromIntegral (length (capVersions caps))) traverse_ (putWord8 . pvNumber) (capVersions caps) @@ -143,8 +150,8 @@ tests = testProperty "missing required extensions" $ \(InvalidExtensions exts) -> isLeft (findExtensions exts), testProperty "lifetime extension" $ \(LifetimeAndExtension ext lt) -> - decodeExtension ext == Just (SomeExtension SLifetimeExtensionTag lt), + decodeExtension ext == Right (Just (SomeExtension SLifetimeExtensionTag lt)), testProperty "capabilities extension" $ \(CapabilitiesAndExtension ext caps) -> - decodeExtension ext == Just (SomeExtension SCapabilitiesExtensionTag caps) + decodeExtension ext == Right (Just (SomeExtension SCapabilitiesExtensionTag caps)) ] ]