diff --git a/cassandra-schema.cql b/cassandra-schema.cql index ce8798ca5d..a053e0551c 100644 --- a/cassandra-schema.cql +++ b/cassandra-schema.cql @@ -477,6 +477,26 @@ CREATE TABLE galley_test.team_conv ( AND read_repair_chance = 0.0 AND speculative_retry = '99PERCENTILE'; +CREATE TABLE galley_test.mls_commit_locks ( + group_id blob, + epoch bigint, + PRIMARY KEY (group_id, epoch) +) WITH CLUSTERING ORDER BY (epoch ASC) + AND bloom_filter_fp_chance = 0.01 + AND caching = {'keys': 'ALL', 'rows_per_partition': 'NONE'} + AND comment = '' + AND compaction = {'class': 'org.apache.cassandra.db.compaction.SizeTieredCompactionStrategy', 'max_threshold': '32', 'min_threshold': '4'} + AND compression = {'chunk_length_in_kb': '64', 'class': 'org.apache.cassandra.io.compress.LZ4Compressor'} + AND crc_check_chance = 1.0 + AND dclocal_read_repair_chance = 0.1 + AND default_time_to_live = 0 + AND gc_grace_seconds = 864000 + AND max_index_interval = 2048 + AND memtable_flush_period_in_ms = 0 + AND min_index_interval = 128 + AND read_repair_chance = 0.0 + AND speculative_retry = '99PERCENTILE'; + CREATE TABLE galley_test.team ( team uuid PRIMARY KEY, binding boolean, diff --git a/changelog.d/2-features/atomic-commits b/changelog.d/2-features/atomic-commits new file mode 100644 index 0000000000..438c76456d --- /dev/null +++ b/changelog.d/2-features/atomic-commits @@ -0,0 +1 @@ +Prevent race conditions in concurrent MLS commit requests. 
\ No newline at end of file diff --git a/services/galley/galley.cabal b/services/galley/galley.cabal index 5e1abe32f2..bfa1038b0d 100644 --- a/services/galley/galley.cabal +++ b/services/galley/galley.cabal @@ -56,6 +56,7 @@ library Galley.Cassandra.Code Galley.Cassandra.Conversation Galley.Cassandra.Conversation.Members + Galley.Cassandra.Conversation.MLS Galley.Cassandra.ConversationList Galley.Cassandra.CustomBackend Galley.Cassandra.Instances @@ -625,6 +626,7 @@ executable galley-schema V65_MLSRemoteClients V66_AddSplashScreen V67_MLSFeature + V68_MLSCommitLock Paths_galley hs-source-dirs: schema/src diff --git a/services/galley/schema/src/Main.hs b/services/galley/schema/src/Main.hs index 1ae83a6ff7..2239be7cef 100644 --- a/services/galley/schema/src/Main.hs +++ b/services/galley/schema/src/Main.hs @@ -70,6 +70,7 @@ import qualified V64_Epoch import qualified V65_MLSRemoteClients import qualified V66_AddSplashScreen import qualified V67_MLSFeature +import qualified V68_MLSCommitLock main :: IO () main = do @@ -125,7 +126,8 @@ main = do V64_Epoch.migration, V65_MLSRemoteClients.migration, V66_AddSplashScreen.migration, - V67_MLSFeature.migration + V67_MLSFeature.migration, + V68_MLSCommitLock.migration -- When adding migrations here, don't forget to update -- 'schemaVersion' in Galley.Cassandra -- (see also docs/developer/cassandra-interaction.md) diff --git a/services/galley/schema/src/V68_MLSCommitLock.hs b/services/galley/schema/src/V68_MLSCommitLock.hs new file mode 100644 index 0000000000..33edb23673 --- /dev/null +++ b/services/galley/schema/src/V68_MLSCommitLock.hs @@ -0,0 +1,33 @@ +-- This file is part of the Wire Server implementation. +-- +-- Copyright (C) 2022 Wire Swiss GmbH +-- +-- This program is free software: you can redistribute it and/or modify it under +-- the terms of the GNU Affero General Public License as published by the Free +-- Software Foundation, either version 3 of the License, or (at your option) any +-- later version. 
+-- +-- This program is distributed in the hope that it will be useful, but WITHOUT +-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more +-- details. +-- +-- You should have received a copy of the GNU Affero General Public License along +-- with this program. If not, see <https://www.gnu.org/licenses/>. + +module V68_MLSCommitLock where + +import Cassandra.Schema +import Imports +import Text.RawString.QQ + +migration :: Migration +migration = + Migration 68 "Add lock table for MLS commits" $ + schema' + [r| CREATE TABLE mls_commit_locks ( + group_id blob, + epoch bigint, + PRIMARY KEY (group_id, epoch) + ) + |] diff --git a/services/galley/src/Galley/API/MLS/Message.hs b/services/galley/src/Galley/API/MLS/Message.hs index 1d62914d24..da8cd90a58 100644 --- a/services/galley/src/Galley/API/MLS/Message.hs +++ b/services/galley/src/Galley/API/MLS/Message.hs @@ -36,6 +36,7 @@ import Galley.API.Util import Galley.Data.Conversation.Types hiding (Conversation) import qualified Galley.Data.Conversation.Types as Data import Galley.Data.Services +import Galley.Data.Types import Galley.Effects import Galley.Effects.BrigAccess import Galley.Effects.ConversationStore @@ -49,6 +50,7 @@ import Polysemy import Polysemy.Error import Polysemy.Input import Polysemy.Internal +import Polysemy.Resource (Resource, bracket) import Polysemy.TinyLog import qualified System.Logger.Class as Logger import Wire.API.Conversation.Protocol @@ -63,6 +65,7 @@ import Wire.API.Federation.Error import Wire.API.MLS.CipherSuite import Wire.API.MLS.Commit import Wire.API.MLS.Credential +import Wire.API.MLS.Group import Wire.API.MLS.KeyPackage import Wire.API.MLS.Message import Wire.API.MLS.Proposal @@ -72,7 +75,8 @@ import Wire.API.Message postMLSMessage :: ( HasProposalEffects r, Members - '[ Error FederationError, + '[ Resource, + Error FederationError, ErrorS 'ConvNotFound, Error InternalError, ErrorS 'MLSUnsupportedMessage, 
@@ -154,7 +158,8 @@ processCommit :: Member (ErrorS 'MLSProposalNotFound) r, Member (Error FederationError) r, Member (Error InternalError) r, - Member (ErrorS 'MissingLegalholdConsent) r + Member (ErrorS 'MissingLegalholdConsent) r, + Member Resource r ) => Local UserId -> ConnId -> @@ -167,34 +172,40 @@ processCommit lusr con conv epoch sender commit = do self <- noteS @'ConvNotFound $ getConvMember lusr conv lusr -- check epoch number - curEpoch <- - preview (to convProtocol . _ProtocolMLS . to cnvmlsEpoch) conv + convMeta <- + preview (to convProtocol . _ProtocolMLS) conv & noteS @'ConvNotFound + + let curEpoch = cnvmlsEpoch convMeta + groupId = cnvmlsGroupId convMeta + when (epoch /= curEpoch) $ throwS @'MLSStaleMessage - when (epoch == Epoch 0) $ do - -- this is a newly created conversation, and it should contain exactly one - -- client (the creator) - case (sender, toList (lmMLSClients self)) of - (MemberSender ref, [creatorClient]) -> do - -- register the creator client - addKeyPackageRef - ref - (qUntagged lusr) - creatorClient - (qUntagged (qualifyAs lusr (Data.convId conv))) - (MemberSender _, _) -> - throw (InternalErrorWithDescription "Unexpected creator client set") - _ -> throw (mlsProtocolError "Unexpected sender") - - -- process and execute proposals - action <- foldMap applyProposalRef (cProposals commit) - events <- executeProposalAction lusr con conv action - - -- increment epoch number - setConversationEpoch (Data.convId conv) (succ epoch) - - pure events + let ttlSeconds :: Int = 600 -- 10 minutes + withCommitLock groupId epoch (fromIntegral ttlSeconds) $ do + when (epoch == Epoch 0) $ do + -- this is a newly created conversation, and it should contain exactly one + -- client (the creator) + case (sender, toList (lmMLSClients self)) of + (MemberSender ref, [creatorClient]) -> do + -- register the creator client + addKeyPackageRef + ref + (qUntagged lusr) + creatorClient + (qUntagged (qualifyAs lusr (Data.convId conv))) + (MemberSender _, _) -> 
+ throw (InternalErrorWithDescription "Unexpected creator client set") + _ -> throw (mlsProtocolError "Unexpected sender") + + -- process and execute proposals + action <- foldMap applyProposalRef (cProposals commit) + events <- executeProposalAction lusr con conv action + + -- increment epoch number + setConversationEpoch (Data.convId conv) (succ epoch) + + pure events applyProposalRef :: ( HasProposalEffects r, @@ -414,3 +425,26 @@ instance HandleMLSProposalFailure (Error e) r where handleMLSProposalFailure = mapError (MLSProposalFailure . toWai) + +withCommitLock :: + forall r a. + ( Members + '[ Resource, + ConversationStore, + ErrorS 'MLSStaleMessage + ] + r + ) => + GroupId -> + Epoch -> + NominalDiffTime -> + Sem r a -> + Sem r a +withCommitLock gid epoch ttl action = + bracket + ( acquireCommitLock gid epoch ttl >>= \lockAcquired -> + when (lockAcquired == NotAcquired) $ + throwS @'MLSStaleMessage + ) + (const $ releaseCommitLock gid epoch) + (const action) diff --git a/services/galley/src/Galley/App.hs b/services/galley/src/Galley/App.hs index 868bd15e3d..f5ec702643 100644 --- a/services/galley/src/Galley/App.hs +++ b/services/galley/src/Galley/App.hs @@ -93,6 +93,7 @@ import Polysemy import Polysemy.Error import Polysemy.Input import Polysemy.Internal (Append) +import Polysemy.Resource (Resource, runResource) import qualified Polysemy.TinyLog as P import qualified Servant import Ssl.Util @@ -114,6 +115,7 @@ type GalleyEffects0 = -- federation errors can be thrown by almost every endpoint, so we avoid -- having to declare it every single time, and simply handle it here Error FederationError, + Resource, Embed IO, Final IO ] @@ -226,6 +228,7 @@ evalGalley :: Env -> Sem GalleyEffects a -> IO a evalGalley e = runFinal @IO . embedToFinal @IO + . runResource . interpretWaiErrorToException . interpretWaiErrorToException . 
interpretWaiErrorToException diff --git a/services/galley/src/Galley/Cassandra.hs b/services/galley/src/Galley/Cassandra.hs index 75a2a880c5..e1250c5b4a 100644 --- a/services/galley/src/Galley/Cassandra.hs +++ b/services/galley/src/Galley/Cassandra.hs @@ -20,4 +20,4 @@ module Galley.Cassandra (schemaVersion) where import Imports schemaVersion :: Int32 -schemaVersion = 67 +schemaVersion = 68 diff --git a/services/galley/src/Galley/Cassandra/Conversation.hs b/services/galley/src/Galley/Cassandra/Conversation.hs index d806b2e492..95e94c427e 100644 --- a/services/galley/src/Galley/Cassandra/Conversation.hs +++ b/services/galley/src/Galley/Cassandra/Conversation.hs @@ -35,6 +35,7 @@ import Data.Range import qualified Data.Set as Set import Data.UUID.V4 (nextRandom) import Galley.Cassandra.Access +import Galley.Cassandra.Conversation.MLS import Galley.Cassandra.Conversation.Members import qualified Galley.Cassandra.Queries as Cql import Galley.Cassandra.Store @@ -296,3 +297,5 @@ interpretConversationStoreToCassandra = interpret $ \case SetConversationEpoch cid epoch -> embedClient $ updateConvEpoch cid epoch DeleteConversation cid -> embedClient $ deleteConversation cid SetGroupId gId cid -> embedClient $ mapGroupId gId cid + AcquireCommitLock gId epoch ttl -> embedClient $ acquireCommitLock gId epoch ttl + ReleaseCommitLock gId epoch -> embedClient $ releaseCommitLock gId epoch diff --git a/services/galley/src/Galley/Cassandra/Conversation/MLS.hs b/services/galley/src/Galley/Cassandra/Conversation/MLS.hs new file mode 100644 index 0000000000..607e386609 --- /dev/null +++ b/services/galley/src/Galley/Cassandra/Conversation/MLS.hs @@ -0,0 +1,56 @@ +-- This file is part of the Wire Server implementation. 
+-- +-- Copyright (C) 2022 Wire Swiss GmbH +-- +-- This program is free software: you can redistribute it and/or modify it under +-- the terms of the GNU Affero General Public License as published by the Free +-- Software Foundation, either version 3 of the License, or (at your option) any +-- later version. +-- +-- This program is distributed in the hope that it will be useful, but WITHOUT +-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more +-- details. +-- +-- You should have received a copy of the GNU Affero General Public License along +-- with this program. If not, see <https://www.gnu.org/licenses/>. + +module Galley.Cassandra.Conversation.MLS where + +import Cassandra +import Cassandra.Settings (fromRow) +import Data.Time +import qualified Galley.Cassandra.Queries as Cql +import Galley.Data.Types +import Imports +import Wire.API.MLS.Group +import Wire.API.MLS.Message + +acquireCommitLock :: GroupId -> Epoch -> NominalDiffTime -> Client LockAcquired +acquireCommitLock groupId epoch ttl = do + rows <- + retry x5 $ + trans + Cql.acquireCommitLock + ( params + LocalQuorum + (groupId, epoch, round ttl) + ) + pure $ + if checkTransSuccess rows + then Acquired + else NotAcquired + +releaseCommitLock :: GroupId -> Epoch -> Client () +releaseCommitLock groupId epoch = + retry x5 $ + write + Cql.releaseCommitLock + ( params + LocalQuorum + (groupId, epoch) + ) + +checkTransSuccess :: [Row] -> Bool +checkTransSuccess [] = False +checkTransSuccess (row : _) = either (const False) (fromMaybe False) $ fromRow 0 row diff --git a/services/galley/src/Galley/Cassandra/Queries.hs b/services/galley/src/Galley/Cassandra/Queries.hs index 9c55029465..b955d30573 100644 --- a/services/galley/src/Galley/Cassandra/Queries.hs +++ b/services/galley/src/Galley/Cassandra/Queries.hs @@ -370,6 +370,12 @@ addLocalMLSClients = "update member set mls_clients = mls_clients + ? 
where conv addRemoteMLSClients :: PrepQuery W (C.Set ClientId, ConvId, Domain, UserId) () addRemoteMLSClients = "update member_remote_user set mls_clients = mls_clients + ? where conv = ? and user_remote_domain = ? and user_remote_id = ?" +acquireCommitLock :: PrepQuery W (GroupId, Epoch, Int32) Row +acquireCommitLock = "insert into mls_commit_locks (group_id, epoch) values (?, ?) if not exists using ttl ?" + +releaseCommitLock :: PrepQuery W (GroupId, Epoch) () +releaseCommitLock = "delete from mls_commit_locks where group_id = ? and epoch = ?" + -- Services ----------------------------------------------------------------- rmSrv :: PrepQuery W (ProviderId, ServiceId) () diff --git a/services/galley/src/Galley/Data/Types.hs b/services/galley/src/Galley/Data/Types.hs index ac485d2bd0..a314af11db 100644 --- a/services/galley/src/Galley/Data/Types.hs +++ b/services/galley/src/Galley/Data/Types.hs @@ -26,6 +26,7 @@ module Galley.Data.Types toCode, generate, mkKey, + LockAcquired (..), ) where @@ -87,3 +88,8 @@ mkKey :: MonadIO m => ConvId -> m Key mkKey cnv = do sha256 <- liftIO $ fromJust <$> getDigestByName "SHA256" pure $ Key . unsafeRange . Ascii.encodeBase64Url . 
BS.take 15 $ digestBS sha256 (toByteString' cnv) + +data LockAcquired + = Acquired + | NotAcquired + deriving (Show, Eq) diff --git a/services/galley/src/Galley/Effects/ConversationStore.hs b/services/galley/src/Galley/Effects/ConversationStore.hs index aa4864b845..0765a3f71f 100644 --- a/services/galley/src/Galley/Effects/ConversationStore.hs +++ b/services/galley/src/Galley/Effects/ConversationStore.hs @@ -46,6 +46,10 @@ module Galley.Effects.ConversationStore -- * Delete conversation deleteConversation, + + -- * MLS commit lock management + acquireCommitLock, + releaseCommitLock, ) where @@ -53,7 +57,9 @@ import Data.Id import Data.Misc import Data.Qualified import Data.Range +import Data.Time (NominalDiffTime) import Galley.Data.Conversation +import Galley.Data.Types import Galley.Types.Conversations.Members import Imports import Polysemy @@ -81,6 +87,8 @@ data ConversationStore m a where SetConversationMessageTimer :: ConvId -> Maybe Milliseconds -> ConversationStore m () SetConversationEpoch :: ConvId -> Epoch -> ConversationStore m () SetGroupId :: GroupId -> Qualified ConvId -> ConversationStore m () + AcquireCommitLock :: GroupId -> Epoch -> NominalDiffTime -> ConversationStore m LockAcquired + ReleaseCommitLock :: GroupId -> Epoch -> ConversationStore m () makeSem ''ConversationStore diff --git a/services/galley/test/integration/API/MLS.hs b/services/galley/test/integration/API/MLS.hs index 5376473759..bc902c42c1 100644 --- a/services/galley/test/integration/API/MLS.hs +++ b/services/galley/test/integration/API/MLS.hs @@ -23,6 +23,7 @@ import API.MLS.Util import API.Util import Bilge hiding (head) import Bilge.Assert +import Cassandra import Control.Lens (view) import qualified Data.Aeson as Aeson import Data.Default @@ -53,6 +54,8 @@ import Wire.API.Conversation.Role import Wire.API.Event.Conversation import Wire.API.Federation.API.Common import Wire.API.Federation.API.Galley +import Wire.API.MLS.Group (convToGroupId) +import Wire.API.MLS.Message import 
Wire.API.Message tests :: IO TestSetup -> TestTree @@ -78,7 +81,8 @@ tests s = test s "add user with some non-MLS clients" testAddUserWithProteusClients, test s "add new client of an already-present user to a conversation" testAddNewClient, test s "send a stale commit" testStaleCommit, - test s "add remote user to a conversation" testAddRemoteUser + test s "add remote user to a conversation" testAddRemoteUser, + test s "return error when commit is locked" testCommitLock ], testGroup "Application Message" @@ -462,6 +466,71 @@ testAddRemoteUser = do roleNameWireMember event +testCommitLock :: TestM () +testCommitLock = withSystemTempDirectory "mls" $ \tmp -> do + -- create MLS conversation + (creator, users) <- withLastPrekeys $ setupParticipants tmp def ((,LocalUser) <$> [2, 2, 2]) + conversation <- setupGroup tmp CreateConv creator "group" + let (users1, usersX) = splitAt 1 users + let (users2, users3) = splitAt 1 usersX + void $ assertOne users1 + void $ assertOne users2 + void $ assertOne users3 + + -- initial user can be added + do + (commit, welcome) <- + liftIO $ + setupCommit tmp creator "group" "group" $ + users1 >>= toList . pClients + testSuccessfulCommit MessagingSetup {users = users1, ..} + + -- can commit without blocking + do + (commit, welcome) <- + liftIO $ + setupCommit tmp creator "group" "group" $ + users2 >>= toList . pClients + testSuccessfulCommit MessagingSetup {users = users2, ..} + + -- block epoch + casClient <- view tsCass + runClient casClient $ insertLock (convToGroupId (qTagUnsafe conversation)) (Epoch 2) + + -- commit fails due to competing lock + do + (commit, welcome) <- + liftIO $ + setupCommit tmp creator "group" "group" $ + users3 >>= toList . 
pClients + -- assert HTTP 409 on next attempt to commit + err <- testFailedCommit MessagingSetup {..} 409 + liftIO $ Wai.label err @?= "mls-stale-message" + + -- unblock epoch + runClient casClient $ deleteLock (convToGroupId (qTagUnsafe conversation)) (Epoch 2) + where + lock :: PrepQuery W (GroupId, Epoch) () + lock = "insert into mls_commit_locks (group_id, epoch) values (?, ?)" + insertLock groupId epoch = + retry x5 $ + write + lock + ( params + LocalQuorum + (groupId, epoch) + ) + unlock :: PrepQuery W (GroupId, Epoch) () + unlock = "delete from mls_commit_locks where group_id = ? and epoch = ?" + deleteLock groupId epoch = + retry x5 $ + write + unlock + ( params + LocalQuorum + (groupId, epoch) + ) + testRemoteAppMessage :: TestM () testRemoteAppMessage = withSystemTempDirectory "mls" $ \tmp -> do let opts =