diff --git a/changelog.d/5-internal/cannon-rabbitmq-pool b/changelog.d/5-internal/cannon-rabbitmq-pool new file mode 100644 index 00000000000..1e7eb39e5e2 --- /dev/null +++ b/changelog.d/5-internal/cannon-rabbitmq-pool @@ -0,0 +1 @@ +Introduce RabbitMQ connection pool in cannon diff --git a/charts/cannon/templates/configmap.yaml b/charts/cannon/templates/configmap.yaml index 99ffd6f2ede..bf085d9179f 100644 --- a/charts/cannon/templates/configmap.yaml +++ b/charts/cannon/templates/configmap.yaml @@ -29,7 +29,7 @@ data: enableTls: {{ .enableTls }} insecureSkipVerifyTls: {{ .insecureSkipVerifyTls }} {{- if .tlsCaSecretRef }} - caCert: /etc/wire/gundeck/rabbitmq-ca/{{ .tlsCaSecretRef.key }} + caCert: /etc/wire/cannon/rabbitmq-ca/{{ .tlsCaSecretRef.key }} {{- end }} {{- end }} diff --git a/charts/cannon/templates/secret.yaml b/charts/cannon/templates/secret.yaml new file mode 100644 index 00000000000..1b6f9ebd94e --- /dev/null +++ b/charts/cannon/templates/secret.yaml @@ -0,0 +1,14 @@ +apiVersion: v1 +kind: Secret +metadata: + name: cannon + labels: + app: cannon + chart: "{{ .Chart.Name }}-{{ .Chart.Version }}" + release: "{{ .Release.Name }}" + heritage: "{{ .Release.Service }}" +type: Opaque +data: + rabbitmqUsername: {{ .Values.secrets.rabbitmq.username | b64enc | quote }} + rabbitmqPassword: {{ .Values.secrets.rabbitmq.password | b64enc | quote }} + diff --git a/charts/cannon/templates/statefulset.yaml b/charts/cannon/templates/statefulset.yaml index 2931ce01b90..44566c78801 100644 --- a/charts/cannon/templates/statefulset.yaml +++ b/charts/cannon/templates/statefulset.yaml @@ -92,6 +92,17 @@ spec: {{ toYaml .Values.resources | indent 12 }} {{- end }} - name: cannon + env: + - name: RABBITMQ_USERNAME + valueFrom: + secretKeyRef: + name: cannon + key: rabbitmqUsername + - name: RABBITMQ_PASSWORD + valueFrom: + secretKeyRef: + name: cannon + key: rabbitmqPassword image: "{{ .Values.image.repository }}:{{ .Values.image.tag }}" {{- if eq (include 
"includeSecurityContext" .) "true" }} securityContext: @@ -102,6 +113,10 @@ spec: mountPath: /etc/wire/cannon/externalHost - name: cannon-config mountPath: /etc/wire/cannon/conf + {{- if .Values.config.rabbitmq.tlsCaSecretRef }} + - name: rabbitmq-ca + mountPath: "/etc/wire/cannon/rabbitmq-ca/" + {{- end }} ports: - name: http containerPort: {{ .Values.service.internalPort }} @@ -155,3 +170,8 @@ spec: secret: secretName: {{ .Values.service.nginz.tls.secretName }} {{- end }} + {{- if .Values.config.rabbitmq.tlsCaSecretRef }} + - name: rabbitmq-ca + secret: + secretName: {{ .Values.config.rabbitmq.tlsCaSecretRef.name }} + {{- end }} diff --git a/charts/integration/templates/integration-integration.yaml b/charts/integration/templates/integration-integration.yaml index 701c8fec634..dd351986e7b 100644 --- a/charts/integration/templates/integration-integration.yaml +++ b/charts/integration/templates/integration-integration.yaml @@ -264,6 +264,9 @@ spec: - name: rabbitmq-ca mountPath: /etc/wire/gundeck/rabbitmq-ca + - name: rabbitmq-ca + mountPath: /etc/wire/cannon/rabbitmq-ca + {{- if eq (include "useCassandraTLS" .Values.config) "true" }} - name: "integration-cassandra" mountPath: "/certs" diff --git a/integration/test/Test/Events.hs b/integration/test/Test/Events.hs index 29f6fd4f945..2400b1c9f2c 100644 --- a/integration/test/Test/Events.hs +++ b/integration/test/Test/Events.hs @@ -6,6 +6,8 @@ import API.Common import API.Galley import API.Gundeck import qualified Control.Concurrent.Timeout as Timeout +import Control.Monad.Codensity +import Control.Monad.Trans.Class import Control.Retry import Data.ByteString.Conversion (toByteString') import qualified Data.Text as Text @@ -14,193 +16,186 @@ import qualified Network.WebSockets as WS import Notifications import SetupHelpers import Testlib.Prelude hiding (assertNoEvent) -import Testlib.Printing import UnliftIO hiding (handle) --- FUTUREWORK: Investigate why these tests are failing without --- `withModifiedBackend`; No 
events are received otherwise. testConsumeEventsOneWebSocket :: (HasCallStack) => App () testConsumeEventsOneWebSocket = do - withModifiedBackend def \domain -> do - alice <- randomUser domain def + alice <- randomUser OwnDomain def - lastNotifResp <- - retrying - (constantDelay 10_000 <> limitRetries 10) - (\_ resp -> pure $ resp.status == 404) - (\_ -> getLastNotification alice def) - lastNotifId <- lastNotifResp.json %. "id" & asString + lastNotifResp <- + retrying + (constantDelay 10_000 <> limitRetries 10) + (\_ resp -> pure $ resp.status == 404) + (\_ -> getLastNotification alice def) + lastNotifId <- lastNotifResp.json %. "id" & asString - client <- addClient alice def {acapabilities = Just ["consumable-notifications"]} >>= getJSON 201 - clientId <- objId client + client <- addClient alice def {acapabilities = Just ["consumable-notifications"]} >>= getJSON 201 + clientId <- objId client - withEventsWebSocket alice clientId $ \eventsChan ackChan -> do - deliveryTag <- assertEvent eventsChan $ \e -> do - e %. "type" `shouldMatch` "event" - e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" - e %. "data.event.payload.0.client.id" `shouldMatch` clientId - e %. "data.delivery_tag" - assertNoEvent eventsChan + runCodensity (createEventsWebSocket alice clientId) $ \ws -> do + deliveryTag <- assertEvent ws $ \e -> do + e %. "type" `shouldMatch` "event" + e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" + e %. "data.event.payload.0.client.id" `shouldMatch` clientId + e %. "data.delivery_tag" + assertNoEvent ws - sendAck ackChan deliveryTag False - assertNoEvent eventsChan + sendAck ws deliveryTag False + assertNoEvent ws - handle <- randomHandle - putHandle alice handle >>= assertSuccess + handle <- randomHandle + putHandle alice handle >>= assertSuccess - assertEvent eventsChan $ \e -> do - e %. "type" `shouldMatch` "event" - e %. "data.event.payload.0.type" `shouldMatch` "user.update" - e %. 
"data.event.payload.0.user.handle" `shouldMatch` handle + assertEvent ws $ \e -> do + e %. "type" `shouldMatch` "event" + e %. "data.event.payload.0.type" `shouldMatch` "user.update" + e %. "data.event.payload.0.user.handle" `shouldMatch` handle - -- No new notifications should be stored in Cassandra as the user doesn't have - -- any legacy clients - getNotifications alice def {since = Just lastNotifId} `bindResponse` \resp -> do - resp.status `shouldMatchInt` 200 - shouldBeEmpty $ resp.json %. "notifications" + -- No new notifications should be stored in Cassandra as the user doesn't have + -- any legacy clients + getNotifications alice def {since = Just lastNotifId} `bindResponse` \resp -> do + resp.status `shouldMatchInt` 200 + shouldBeEmpty $ resp.json %. "notifications" testConsumeEventsForDifferentUsers :: (HasCallStack) => App () testConsumeEventsForDifferentUsers = do - withModifiedBackend def $ \domain -> do - alice <- randomUser domain def - bob <- randomUser domain def + alice <- randomUser OwnDomain def + bob <- randomUser OwnDomain def - aliceClient <- addClient alice def {acapabilities = Just ["consumable-notifications"]} >>= getJSON 201 - aliceClientId <- objId aliceClient + aliceClient <- addClient alice def {acapabilities = Just ["consumable-notifications"]} >>= getJSON 201 + aliceClientId <- objId aliceClient - bobClient <- addClient bob def {acapabilities = Just ["consumable-notifications"]} >>= getJSON 201 - bobClientId <- objId bobClient + bobClient <- addClient bob def {acapabilities = Just ["consumable-notifications"]} >>= getJSON 201 + bobClientId <- objId bobClient - withEventsWebSockets [(alice, aliceClientId), (bob, bobClientId)] $ \[(aliceEventsChan, aliceAckChan), (bobEventsChan, bobAckChan)] -> do - assertClientAdd aliceClientId aliceEventsChan aliceAckChan - assertClientAdd bobClientId bobEventsChan bobAckChan + lowerCodensity $ do + aliceWS <- createEventsWebSocket alice aliceClientId + bobWS <- createEventsWebSocket bob bobClientId 
+ lift $ assertClientAdd aliceClientId aliceWS + lift $ assertClientAdd bobClientId bobWS where - assertClientAdd :: (HasCallStack) => String -> TChan Value -> TChan Value -> App () - assertClientAdd clientId eventsChan ackChan = do - deliveryTag <- assertEvent eventsChan $ \e -> do + assertClientAdd :: (HasCallStack) => String -> EventWebSocket -> App () + assertClientAdd clientId ws = do + deliveryTag <- assertEvent ws $ \e -> do e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" e %. "data.event.payload.0.client.id" `shouldMatch` clientId e %. "data.delivery_tag" - assertNoEvent eventsChan - sendAck ackChan deliveryTag False + assertNoEvent ws + sendAck ws deliveryTag False testConsumeEventsWhileHavingLegacyClients :: (HasCallStack) => App () testConsumeEventsWhileHavingLegacyClients = do - withModifiedBackend def $ \domain -> do - alice <- randomUser domain def + alice <- randomUser OwnDomain def - -- Even if alice has no clients, the notifications should still be persisted - -- in Cassandra. This choice is kinda arbitrary as these notifications - -- probably don't mean much, however, it ensures backwards compatibility. - lastNotifId <- - awaitNotification alice noValue (const $ pure True) >>= \notif -> do - notif %. "payload.0.type" `shouldMatch` "user.activate" - -- There is only one notification (at the time of writing), so we assume - -- it to be the last one. - notif %. "id" & asString + -- Even if alice has no clients, the notifications should still be persisted + -- in Cassandra. This choice is kinda arbitrary as these notifications + -- probably don't mean much, however, it ensures backwards compatibility. + lastNotifId <- + awaitNotification alice noValue (const $ pure True) >>= \notif -> do + notif %. "payload.0.type" `shouldMatch` "user.activate" + -- There is only one notification (at the time of writing), so we assume + -- it to be the last one. + notif %. 
"id" & asString - oldClient <- addClient alice def {acapabilities = Just []} >>= getJSON 201 + oldClient <- addClient alice def {acapabilities = Just []} >>= getJSON 201 - withWebSocket (alice, "anything-but-conn", oldClient %. "id") $ \oldWS -> do - newClient <- addClient alice def {acapabilities = Just ["consumable-notifications"]} >>= getJSON 201 - newClientId <- newClient %. "id" & asString + withWebSocket (alice, "anything-but-conn", oldClient %. "id") $ \oldWS -> do + newClient <- addClient alice def {acapabilities = Just ["consumable-notifications"]} >>= getJSON 201 + newClientId <- newClient %. "id" & asString - oldNotif <- awaitMatch isUserClientAddNotif oldWS - oldNotif %. "payload.0.client.id" `shouldMatch` newClientId + oldNotif <- awaitMatch isUserClientAddNotif oldWS + oldNotif %. "payload.0.client.id" `shouldMatch` newClientId - withEventsWebSocket alice newClientId $ \eventsChan _ -> - assertEvent eventsChan $ \e -> do - e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" - e %. "data.event.payload.0.client.id" `shouldMatch` newClientId + runCodensity (createEventsWebSocket alice newClientId) $ \ws -> + assertEvent ws $ \e -> do + e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" + e %. "data.event.payload.0.client.id" `shouldMatch` newClientId - -- All notifs are also in Cassandra because of the legacy client - getNotifications alice def {since = Just lastNotifId} `bindResponse` \resp -> do - resp.status `shouldMatchInt` 200 - resp.json %. "notifications.0.payload.0.type" `shouldMatch` "user.client-add" - resp.json %. "notifications.1.payload.0.type" `shouldMatch` "user.client-add" + -- All notifs are also in Cassandra because of the legacy client + getNotifications alice def {since = Just lastNotifId} `bindResponse` \resp -> do + resp.status `shouldMatchInt` 200 + resp.json %. "notifications.0.payload.0.type" `shouldMatch` "user.client-add" + resp.json %. 
"notifications.1.payload.0.type" `shouldMatch` "user.client-add" testConsumeEventsAcks :: (HasCallStack) => App () testConsumeEventsAcks = do - withModifiedBackend def $ \domain -> do - alice <- randomUser domain def - client <- addClient alice def {acapabilities = Just ["consumable-notifications"]} >>= getJSON 201 - clientId <- objId client - - withEventsWebSocket alice clientId $ \eventsChan _ackChan -> do - assertEvent eventsChan $ \e -> do - e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" - e %. "data.event.payload.0.client.id" `shouldMatch` clientId - - -- without ack, we receive the same event again - withEventsWebSocket alice clientId $ \eventsChan ackChan -> do - deliveryTag <- assertEvent eventsChan $ \e -> do - e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" - e %. "data.event.payload.0.client.id" `shouldMatch` clientId - e %. "data.delivery_tag" - sendAck ackChan deliveryTag False - - withEventsWebSocket alice clientId $ \eventsChan _ -> do - assertNoEvent eventsChan + alice <- randomUser OwnDomain def + client <- addClient alice def {acapabilities = Just ["consumable-notifications"]} >>= getJSON 201 + clientId <- objId client + + runCodensity (createEventsWebSocket alice clientId) $ \ws -> do + assertEvent ws $ \e -> do + e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" + e %. "data.event.payload.0.client.id" `shouldMatch` clientId + + -- without ack, we receive the same event again + runCodensity (createEventsWebSocket alice clientId) $ \ws -> do + deliveryTag <- assertEvent ws $ \e -> do + e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" + e %. "data.event.payload.0.client.id" `shouldMatch` clientId + e %. 
"data.delivery_tag" + sendAck ws deliveryTag False + + runCodensity (createEventsWebSocket alice clientId) $ \ws -> do + assertNoEvent ws testConsumeEventsMultipleAcks :: (HasCallStack) => App () testConsumeEventsMultipleAcks = do - withModifiedBackend def $ \domain -> do - alice <- randomUser domain def - client <- addClient alice def {acapabilities = Just ["consumable-notifications"]} >>= getJSON 201 - clientId <- objId client + alice <- randomUser OwnDomain def + client <- addClient alice def {acapabilities = Just ["consumable-notifications"]} >>= getJSON 201 + clientId <- objId client - handle <- randomHandle - putHandle alice handle >>= assertSuccess + handle <- randomHandle + putHandle alice handle >>= assertSuccess - withEventsWebSocket alice clientId $ \eventsChan ackChan -> do - assertEvent eventsChan $ \e -> do - e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" - e %. "data.event.payload.0.client.id" `shouldMatch` clientId + runCodensity (createEventsWebSocket alice clientId) $ \ws -> do + assertEvent ws $ \e -> do + e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" + e %. "data.event.payload.0.client.id" `shouldMatch` clientId - deliveryTag <- assertEvent eventsChan $ \e -> do - e %. "data.event.payload.0.type" `shouldMatch` "user.update" - e %. "data.event.payload.0.user.handle" `shouldMatch` handle - e %. "data.delivery_tag" + deliveryTag <- assertEvent ws $ \e -> do + e %. "data.event.payload.0.type" `shouldMatch` "user.update" + e %. "data.event.payload.0.user.handle" `shouldMatch` handle + e %. 
"data.delivery_tag" - sendAck ackChan deliveryTag True + sendAck ws deliveryTag True - withEventsWebSocket alice clientId $ \eventsChan _ -> do - assertNoEvent eventsChan + runCodensity (createEventsWebSocket alice clientId) $ \ws -> do + assertNoEvent ws testConsumeEventsAckNewEventWithoutAckingOldOne :: (HasCallStack) => App () testConsumeEventsAckNewEventWithoutAckingOldOne = do - withModifiedBackend def $ \domain -> do - alice <- randomUser domain def - client <- addClient alice def {acapabilities = Just ["consumable-notifications"]} >>= getJSON 201 - clientId <- objId client + alice <- randomUser OwnDomain def + client <- addClient alice def {acapabilities = Just ["consumable-notifications"]} >>= getJSON 201 + clientId <- objId client - handle <- randomHandle - putHandle alice handle >>= assertSuccess + handle <- randomHandle + putHandle alice handle >>= assertSuccess - withEventsWebSocket alice clientId $ \eventsChan ackChan -> do - assertEvent eventsChan $ \e -> do - e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" - e %. "data.event.payload.0.client.id" `shouldMatch` clientId + runCodensity (createEventsWebSocket alice clientId) $ \ws -> do + assertEvent ws $ \e -> do + e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" + e %. "data.event.payload.0.client.id" `shouldMatch` clientId - deliveryTagHandleAdd <- assertEvent eventsChan $ \e -> do - e %. "data.event.payload.0.type" `shouldMatch` "user.update" - e %. "data.event.payload.0.user.handle" `shouldMatch` handle - e %. "data.delivery_tag" + deliveryTagHandleAdd <- assertEvent ws $ \e -> do + e %. "data.event.payload.0.type" `shouldMatch` "user.update" + e %. "data.event.payload.0.user.handle" `shouldMatch` handle + e %. "data.delivery_tag" - -- Only ack the handle add delivery tag - sendAck ackChan deliveryTagHandleAdd False + -- Only ack the handle add delivery tag + sendAck ws deliveryTagHandleAdd False - -- Expect client-add event to be delivered again. 
- withEventsWebSocket alice clientId $ \eventsChan ackChan -> do - deliveryTagClientAdd <- assertEvent eventsChan $ \e -> do - e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" - e %. "data.event.payload.0.client.id" `shouldMatch` clientId - e %. "data.delivery_tag" + -- Expect client-add event to be delivered again. + runCodensity (createEventsWebSocket alice clientId) $ \ws -> do + deliveryTagClientAdd <- assertEvent ws $ \e -> do + e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" + e %. "data.event.payload.0.client.id" `shouldMatch` clientId + e %. "data.delivery_tag" - sendAck ackChan deliveryTagClientAdd False + sendAck ws deliveryTagClientAdd False - withEventsWebSocket alice clientId $ \eventsChan _ -> do - assertNoEvent eventsChan + runCodensity (createEventsWebSocket alice clientId) $ \ws -> do + assertNoEvent ws testEventsDeadLettered :: (HasCallStack) => App () testEventsDeadLettered = do @@ -219,22 +214,22 @@ testEventsDeadLettered = do handle1 <- randomHandle putHandle alice handle1 >>= assertSuccess - withEventsWebSocket alice clientId $ \eventsChan ackChan -> do - assertEvent eventsChan $ \e -> do + runCodensity (createEventsWebSocket alice clientId) $ \ws -> do + assertEvent ws $ \e -> do e %. "type" `shouldMatch` "notifications.missed" -- Until we ack the full sync, we can't get new events - ackFullSync ackChan + ackFullSync ws -- withEventsWebSocket alice clientId $ \eventsChan ackChan -> do -- Now we can see the next event - assertEvent eventsChan $ \e -> do + assertEvent ws $ \e -> do e %. "data.event.payload.0.type" `shouldMatch` "user.update" e %. "data.event.payload.0.user.handle" `shouldMatch` handle1 - ackEvent ackChan e + ackEvent ws e -- We've consumed the whole queue. 
- assertNoEvent eventsChan + assertNoEvent ws testTransientEventsDoNotTriggerDeadLetters :: (HasCallStack) => App () testTransientEventsDoNotTriggerDeadLetters = do @@ -246,14 +241,14 @@ testTransientEventsDoNotTriggerDeadLetters = do clientId <- objId client -- consume it - withEventsWebSocket alice clientId $ \eventsChan ackChan -> do - assertEvent eventsChan $ \e -> do + runCodensity (createEventsWebSocket alice clientId) $ \ws -> do + assertEvent ws $ \e -> do e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" e %. "type" `shouldMatch` "event" e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" e %. "data.event.payload.0.client.id" `shouldMatch` clientId deliveryTag <- e %. "data.delivery_tag" - sendAck ackChan deliveryTag False + sendAck ws deliveryTag False -- Self conv ID is same as user's ID, we'll use this to send typing -- indicators, so we don't have to create another conv. @@ -261,107 +256,153 @@ testTransientEventsDoNotTriggerDeadLetters = do -- Typing status is transient, currently no one is listening. sendTypingStatus alice selfConvId "started" >>= assertSuccess - withEventsWebSocket alice clientId $ \eventsChan _ackChan -> do - assertNoEvent eventsChan + runCodensity (createEventsWebSocket alice clientId) $ \ws -> do + assertNoEvent ws testTransientEvents :: (HasCallStack) => App () testTransientEvents = do - withModifiedBackend def $ \domain -> do - alice <- randomUser domain def - client <- addClient alice def {acapabilities = Just ["consumable-notifications"]} >>= getJSON 201 - clientId <- objId client - - -- Self conv ID is same as user's ID, we'll use this to send typing - -- indicators, so we don't have to create another conv. - selfConvId <- objQidObject alice - - withEventsWebSocket alice clientId $ \eventsChan ackChan -> do - consumeAllEvents eventsChan ackChan - sendTypingStatus alice selfConvId "started" >>= assertSuccess - assertEvent eventsChan $ \e -> do - e %. 
"data.event.payload.0.type" `shouldMatch` "conversation.typing" - e %. "data.event.payload.0.qualified_conversation" `shouldMatch` selfConvId - deliveryTag <- e %. "data.delivery_tag" - sendAck ackChan deliveryTag False + alice <- randomUser OwnDomain def + client <- addClient alice def {acapabilities = Just ["consumable-notifications"]} >>= getJSON 201 + clientId <- objId client - handle1 <- randomHandle - putHandle alice handle1 >>= assertSuccess + -- Self conv ID is same as user's ID, we'll use this to send typing + -- indicators, so we don't have to create another conv. + selfConvId <- objQidObject alice - sendTypingStatus alice selfConvId "stopped" >>= assertSuccess - - handle2 <- randomHandle - putHandle alice handle2 >>= assertSuccess - - -- We shouldn't see the stopped typing status because we were not connected to - -- the websocket when it was sent. The other events should still show up in - -- order. - withEventsWebSocket alice clientId $ \eventsChan ackChan -> do - for_ [handle1, handle2] $ \handle -> - assertEvent eventsChan $ \e -> do - e %. "data.event.payload.0.type" `shouldMatch` "user.update" - e %. "data.event.payload.0.user.handle" `shouldMatch` handle - ackEvent ackChan e + runCodensity (createEventsWebSocket alice clientId) $ \ws -> do + consumeAllEvents ws + sendTypingStatus alice selfConvId "started" >>= assertSuccess + assertEvent ws $ \e -> do + e %. "data.event.payload.0.type" `shouldMatch` "conversation.typing" + e %. "data.event.payload.0.qualified_conversation" `shouldMatch` selfConvId + deliveryTag <- e %. "data.delivery_tag" + sendAck ws deliveryTag False + + handle1 <- randomHandle + putHandle alice handle1 >>= assertSuccess + + sendTypingStatus alice selfConvId "stopped" >>= assertSuccess + + handle2 <- randomHandle + putHandle alice handle2 >>= assertSuccess + + -- We shouldn't see the stopped typing status because we were not connected to + -- the websocket when it was sent. The other events should still show up in + -- order. 
+ runCodensity (createEventsWebSocket alice clientId) $ \ws -> do + for_ [handle1, handle2] $ \handle -> + assertEvent ws $ \e -> do + e %. "data.event.payload.0.type" `shouldMatch` "user.update" + e %. "data.event.payload.0.user.handle" `shouldMatch` handle + ackEvent ws e + + assertNoEvent ws + +testChannelLimit :: (HasCallStack) => App () +testChannelLimit = withModifiedBackend + ( def + { cannonCfg = + setField "rabbitMqMaxChannels" (2 :: Int) + >=> setField "rabbitMqMaxConnections" (1 :: Int) + } + ) + $ \domain -> do + alice <- randomUser domain def + (client0 : clients) <- + replicateM 3 + $ addClient alice def {acapabilities = Just ["consumable-notifications"]} + >>= getJSON 201 + >>= (%. "id") + >>= asString + + lowerCodensity $ do + for_ clients $ \c -> do + ws <- createEventsWebSocket alice c + e <- Codensity $ \k -> assertEvent ws k + lift $ do + e %. "data.event.payload.0.type" `shouldMatch` "user.client-add" + e %. "data.event.payload.0.client.id" `shouldMatch` c + e %. "data.delivery_tag" - assertNoEvent eventsChan + -- the first client fails to connect because the server runs out of channels + do + ws <- createEventsWebSocket alice client0 + lift $ assertNoEvent ws ---------------------------------------------------------------------- -- helpers -withEventsWebSockets :: forall uid a. 
(HasCallStack, MakesValue uid) => [(uid, String)] -> ([(TChan Value, TChan Value)] -> App a) -> App a -withEventsWebSockets userClients k = go [] $ reverse userClients - where - go :: [(TChan Value, TChan Value)] -> [(uid, String)] -> App a - go chans [] = k chans - go chans ((uid, cid) : remaining) = - withEventsWebSocket uid cid $ \eventsChan ackChan -> - go ((eventsChan, ackChan) : chans) remaining - -withEventsWebSocket :: (HasCallStack, MakesValue uid) => uid -> String -> (TChan Value -> TChan Value -> App a) -> App a -withEventsWebSocket uid cid k = do - closeWS <- newEmptyMVar - bracket (setup closeWS) (\(_, _, wsThread) -> cancel wsThread) $ \(eventsChan, ackChan, wsThread) -> do - x <- k eventsChan ackChan - - -- Ensure all the acks are sent before closing the websocket - isAckChanEmpty <- - retrying - (limitRetries 5 <> constantDelay 10_000) - (\_ isEmpty -> pure $ not isEmpty) - (\_ -> atomically $ isEmptyTChan ackChan) - unless isAckChanEmpty $ do - putStrLn $ colored yellow $ "The ack chan is not empty after 50ms, some acks may not make it to the server" - - void $ tryPutMVar closeWS () - - timeout 1_000_000 (wait wsThread) >>= \case - Nothing -> - putStrLn $ colored yellow $ "The websocket thread did not close after waiting for 1s" - Just () -> pure () - - pure x - where - setup :: (HasCallStack) => MVar () -> App (TChan Value, TChan Value, Async ()) - setup closeWS = do - (eventsChan, ackChan) <- liftIO $ (,) <$> newTChanIO <*> newTChanIO - wsThread <- eventsWebSocket uid cid eventsChan ackChan closeWS - pure (eventsChan, ackChan, wsThread) - -sendMsg :: (HasCallStack) => TChan Value -> Value -> App () -sendMsg eventsChan msg = liftIO $ atomically $ writeTChan eventsChan msg - -ackFullSync :: (HasCallStack) => TChan Value -> App () -ackFullSync ackChan = do - sendMsg ackChan - $ object ["type" .= "ack_full_sync"] - -ackEvent :: (HasCallStack) => TChan Value -> Value -> App () -ackEvent ackChan event = do +data EventWebSocket = EventWebSocket + { 
events :: Chan (Either WS.ConnectionException Value), + ack :: MVar (Maybe Value) + } + +createEventsWebSocket :: + (HasCallStack, MakesValue uid) => + uid -> + String -> + Codensity App EventWebSocket +createEventsWebSocket user cid = do + eventsChan <- liftIO newChan + ackChan <- liftIO newEmptyMVar + serviceMap <- lift $ getServiceMap =<< objDomain user + uid <- lift $ objId =<< objQidObject user + let HostPort caHost caPort = serviceHostPort serviceMap Cannon + path = "/events?client=" <> cid + caHdrs = [(fromString "Z-User", toByteString' uid)] + app conn = + race_ + (wsRead conn `catch` (writeChan eventsChan . Left)) + (wsWrite conn) + + wsRead conn = forever $ do + bs <- WS.receiveData conn + case decodeStrict' bs of + Just n -> writeChan eventsChan (Right n) + Nothing -> + error $ "Failed to decode events: " ++ show bs + + wsWrite conn = do + mAck <- takeMVar ackChan + case mAck of + Nothing -> WS.sendClose conn (Text.pack "") + Just ack -> + WS.sendBinaryData conn (encode ack) + >> wsWrite conn + + wsThread <- Codensity $ \k -> do + withAsync + ( liftIO + $ WS.runClientWith + caHost + (fromIntegral caPort) + path + WS.defaultConnectionOptions + caHdrs + app + ) + k + + Codensity $ \k -> + k (EventWebSocket eventsChan ackChan) `finally` do + putMVar ackChan Nothing + liftIO $ wait wsThread + +ackFullSync :: (HasCallStack) => EventWebSocket -> App () +ackFullSync ws = + putMVar ws.ack + $ Just (object ["type" .= "ack_full_sync"]) + +ackEvent :: (HasCallStack) => EventWebSocket -> Value -> App () +ackEvent ws event = do deliveryTag <- event %. 
"data.delivery_tag" - sendAck ackChan deliveryTag False + sendAck ws deliveryTag False -sendAck :: (HasCallStack) => TChan Value -> Value -> Bool -> App () -sendAck ackChan deliveryTag multiple = do - sendMsg ackChan +sendAck :: (HasCallStack) => EventWebSocket -> Value -> Bool -> App () +sendAck ws deliveryTag multiple = + do + putMVar $ ws.ack + $ Just + $ object [ "type" .= "ack", "data" @@ -371,65 +412,33 @@ sendAck ackChan deliveryTag multiple = do ] ] -assertEvent :: (HasCallStack) => TChan Value -> ((HasCallStack) => Value -> App a) -> App a -assertEvent eventsChan expectations = do - timeout 10_000_000 (atomically (readTChan eventsChan)) >>= \case - Nothing -> assertFailure "No event received for 10s" - Just e -> do +assertEvent :: (HasCallStack) => EventWebSocket -> ((HasCallStack) => Value -> App a) -> App a +assertEvent ws expectations = do + timeout 10_000_000 (readChan ws.events) >>= \case + Nothing -> assertFailure "No event received for 10s" + Just (Left _) -> assertFailure "Websocket closed when waiting for more events" + Just (Right e) -> do pretty <- prettyJSON e addFailureContext ("event:\n" <> pretty) $ expectations e -assertNoEvent :: (HasCallStack) => TChan Value -> App () -assertNoEvent eventsChan = do - timeout 1_000_000 (atomically (readTChan eventsChan)) >>= \case +assertNoEvent :: (HasCallStack) => EventWebSocket -> App () +assertNoEvent ws = do + timeout 1_000_000 (readChan ws.events) >>= \case Nothing -> pure () - Just e -> do + Just (Left _) -> pure () + Just (Right e) -> do eventJSON <- prettyJSON e assertFailure $ "Did not expect event: \n" <> eventJSON -consumeAllEvents :: TChan Value -> TChan Value -> App () -consumeAllEvents eventsChan ackChan = do - timeout 1_000_000 (atomically (readTChan eventsChan)) >>= \case +consumeAllEvents :: EventWebSocket -> App () +consumeAllEvents ws = do + timeout 1_000_000 (readChan ws.events) >>= \case Nothing -> pure () - Just e -> do - ackEvent ackChan e - consumeAllEvents eventsChan ackChan - 
-eventsWebSocket :: (MakesValue user) => user -> String -> TChan Value -> TChan Value -> MVar () -> App (Async ()) -eventsWebSocket user clientId eventsChan ackChan closeWS = do - serviceMap <- getServiceMap =<< objDomain user - uid <- objId =<< objQidObject user - let HostPort caHost caPort = serviceHostPort serviceMap Cannon - path = "/events?client=" <> clientId - caHdrs = [(fromString "Z-User", toByteString' uid)] - app conn = do - r <- - async $ wsRead conn `catch` \(e :: WS.ConnectionException) -> - case e of - WS.CloseRequest {} -> pure () - _ -> throwIO e - w <- async $ wsWrite conn - void $ waitAny [r, w] - - wsRead conn = forever $ do - bs <- WS.receiveData conn - case decodeStrict' bs of - Just n -> atomically $ writeTChan eventsChan n - Nothing -> - error $ "Failed to decode events: " ++ show bs - - wsWrite conn = forever $ do - eitherAck <- race (readMVar closeWS) (atomically $ readTChan ackChan) - case eitherAck of - Left () -> WS.sendClose conn (Text.pack "") - Right ack -> WS.sendBinaryData conn (encode ack) - liftIO - $ async - $ WS.runClientWith - caHost - (fromIntegral caPort) - path - WS.defaultConnectionOptions - caHdrs - app + Just (Left e) -> + assertFailure + $ "Websocket closed while consuming all events: " + <> displayException e + Just (Right e) -> do + ackEvent ws e + consumeAllEvents ws diff --git a/libs/extended/src/Network/AMQP/Extended.hs b/libs/extended/src/Network/AMQP/Extended.hs index 4aa48aefc5b..1453f3909e4 100644 --- a/libs/extended/src/Network/AMQP/Extended.hs +++ b/libs/extended/src/Network/AMQP/Extended.hs @@ -11,6 +11,8 @@ module Network.AMQP.Extended demoteOpts, RabbitMqTlsOpts (..), mkConnectionOpts, + mkTLSSettings, + readCredsFromEnv, ) where diff --git a/services/background-worker/src/Wire/BackendDeadUserNotificationWatcher.hs b/services/background-worker/src/Wire/BackendDeadUserNotificationWatcher.hs index ad8c3c38254..1d24ab05c6e 100644 --- a/services/background-worker/src/Wire/BackendDeadUserNotificationWatcher.hs 
+++ b/services/background-worker/src/Wire/BackendDeadUserNotificationWatcher.hs @@ -99,15 +99,25 @@ startWorker amqp = do -- If the mvar is filled with a connection, we know the connection itself is fine, -- so we only need to re-open the channel let openConnection connM = do + -- keep track of whether the connection is being closed normally + closingRef <- newIORef False + mConn <- lowerCodensity $ do conn <- case connM of Nothing -> do -- Open the rabbit mq connection - conn <- Codensity $ bracket (liftIO $ Q.openConnection'' connOpts) (liftIO . Q.closeConnection) + conn <- Codensity + $ bracket + (liftIO $ Q.openConnection'' connOpts) + $ \conn -> do + writeIORef closingRef True + liftIO $ Q.closeConnection conn -- We need to recover from connection closed by restarting it liftIO $ Q.addConnectionClosedHandler conn True do - Log.err env.logger $ - Log.msg (Log.val "BackendDeadUserNoticationWatcher: Connection closed.") + closing <- readIORef closingRef + unless closing $ do + Log.err env.logger $ + Log.msg (Log.val "BackendDeadUserNoticationWatcher: Connection closed.") putMVar mVar Nothing runAppT env $ markAsNotWorking BackendDeadUserNoticationWatcher pure conn @@ -118,9 +128,10 @@ startWorker amqp = do -- If the channel stops, we need to re-open liftIO $ Q.addChannelExceptionHandler chan $ \e -> do - Log.err env.logger $ - Log.msg (Log.val "BackendDeadUserNoticationWatcher: Caught exception in RabbitMQ channel.") - . Log.field "exception" (displayException e) + unless (Q.isNormalChannelClose e) $ + Log.err env.logger $ + Log.msg (Log.val "BackendDeadUserNoticationWatcher: Caught exception in RabbitMQ channel.") + . 
Log.field "exception" (displayException e) runAppT env $ markAsNotWorking BackendDeadUserNoticationWatcher putMVar mVar (Just conn) diff --git a/services/cannon/cannon.cabal b/services/cannon/cannon.cabal index 4091540a846..25ad0624593 100644 --- a/services/cannon/cannon.cabal +++ b/services/cannon/cannon.cabal @@ -23,6 +23,7 @@ library Cannon.App Cannon.Dict Cannon.Options + Cannon.RabbitMq Cannon.RabbitMqConsumerApp Cannon.Run Cannon.Types @@ -85,10 +86,12 @@ library , async >=2.0 , base >=4.6 && <5 , bilge >=0.12 + , binary , bytestring >=0.10 , bytestring-conversion >=0.2 , cassandra-util , conduit >=1.3.4.2 + , containers , data-timeout >=0.3 , exceptions >=0.6 , extended @@ -111,6 +114,7 @@ library , strict >=0.3.2 , text >=1.1 , tinylog >=0.10 + , transformers , types-common >=0.16 , unix , unliftio diff --git a/services/cannon/default.nix b/services/cannon/default.nix index 80ad8b8e3ca..c62056faa23 100644 --- a/services/cannon/default.nix +++ b/services/cannon/default.nix @@ -9,10 +9,12 @@ , async , base , bilge +, binary , bytestring , bytestring-conversion , cassandra-util , conduit +, containers , criterion , data-timeout , exceptions @@ -43,6 +45,7 @@ , tasty-quickcheck , text , tinylog +, transformers , types-common , unix , unliftio @@ -69,10 +72,12 @@ mkDerivation { async base bilge + binary bytestring bytestring-conversion cassandra-util conduit + containers data-timeout exceptions extended @@ -95,6 +100,7 @@ mkDerivation { strict text tinylog + transformers types-common unix unliftio diff --git a/services/cannon/src/Cannon/Options.hs b/services/cannon/src/Cannon/Options.hs index dad2ad51924..aa1d1e8d005 100644 --- a/services/cannon/src/Cannon/Options.hs +++ b/services/cannon/src/Cannon/Options.hs @@ -32,12 +32,15 @@ module Cannon.Options drainOpts, rabbitmq, cassandraOpts, + rabbitMqMaxConnections, + rabbitMqMaxChannels, Opts, gracePeriodSeconds, millisecondsBetweenBatches, minBatchSize, disabledAPIVersions, DrainOpts, + validateOpts, ) where @@ 
-98,12 +101,23 @@ data Opts = Opts _optsLogFormat :: !(Maybe (Last LogFormat)), _optsDrainOpts :: DrainOpts, _optsDisabledAPIVersions :: !(Set VersionExp), - _optsCassandraOpts :: !CassandraOpts + _optsCassandraOpts :: !CassandraOpts, + -- | Maximum number of rabbitmq connections. Must be strictly positive. + _optsRabbitMqMaxConnections :: Int, + -- | Maximum number of rabbitmq channels per connection. Must be strictly positive. + _optsRabbitMqMaxChannels :: Int } deriving (Show, Generic) makeFields ''Opts +validateOpts :: Opts -> IO () +validateOpts opts = do + when (opts._optsRabbitMqMaxConnections <= 0) $ do + fail "rabbitMqMaxConnections must be strictly positive" + when (opts._optsRabbitMqMaxChannels <= 0) $ do + fail "rabbitMqMaxChannels must be strictly positive" + instance FromJSON Opts where parseJSON = withObject "CannonOpts" $ \o -> Opts @@ -116,3 +130,5 @@ instance FromJSON Opts where <*> o .: "drainOpts" <*> o .: "disabledAPIVersions" <*> o .: "cassandra" + <*> o .:? "rabbitMqMaxConnections" .!= 1000 + <*> o .:? 
"rabbitMqMaxChannels" .!= 300 diff --git a/services/cannon/src/Cannon/RabbitMq.hs b/services/cannon/src/Cannon/RabbitMq.hs new file mode 100644 index 00000000000..d5a228bd410 --- /dev/null +++ b/services/cannon/src/Cannon/RabbitMq.hs @@ -0,0 +1,321 @@ +{-# LANGUAGE RecordWildCards #-} + +module Cannon.RabbitMq + ( RabbitMqPoolException, + RabbitMqPoolOptions (..), + RabbitMqPool, + createRabbitMqPool, + drainRabbitMqPool, + RabbitMqChannel (..), + createChannel, + getMessage, + ackMessage, + ) +where + +import Cannon.Options +import Control.Concurrent.Async +import Control.Concurrent.Timeout +import Control.Exception +import Control.Lens ((^.)) +import Control.Monad.Codensity +import Control.Monad.Trans.Except +import Control.Monad.Trans.Maybe +import Control.Retry +import Data.ByteString.Conversion +import Data.List.Extra +import Data.Map qualified as Map +import Data.Timeout +import Imports hiding (threadDelay) +import Network.AMQP qualified as Q +import Network.AMQP.Extended +import System.Logger (Logger) +import System.Logger qualified as Log +import UnliftIO (pooledMapConcurrentlyN_) + +data RabbitMqPoolException + = TooManyChannels + | ChannelClosed + deriving (Eq, Show) + +instance Exception RabbitMqPoolException + +data PooledConnection key = PooledConnection + { connId :: Word64, + inner :: Q.Connection, + channels :: !(Map key Q.Channel) + } + +data RabbitMqPool key = RabbitMqPool + { opts :: RabbitMqPoolOptions, + nextId :: TVar Word64, + connections :: TVar [PooledConnection key], + -- | draining mode + draining :: TVar Bool, + logger :: Logger, + deadVar :: MVar () + } + +data RabbitMqPoolOptions = RabbitMqPoolOptions + { maxConnections :: Int, + maxChannels :: Int, + endpoint :: AmqpEndpoint + } + +createRabbitMqPool :: (Ord key) => RabbitMqPoolOptions -> Logger -> Codensity IO (RabbitMqPool key) +createRabbitMqPool opts logger = Codensity $ bracket create destroy + where + create = do + deadVar <- newEmptyMVar + (nextId, connections, draining) <- + 
atomically $ + (,,) <$> newTVar 0 <*> newTVar [] <*> newTVar False + let pool = RabbitMqPool {..} + -- create one connection + void $ createConnection pool + pure pool + destroy pool = putMVar pool.deadVar () + +drainRabbitMqPool :: (ToByteString key) => RabbitMqPool key -> DrainOpts -> IO () +drainRabbitMqPool pool opts = do + atomically $ writeTVar pool.draining True + + channels <- atomically $ do + conns <- readTVar pool.connections + pure $ concat [Map.assocs c.channels | c <- conns] + let numberOfChannels = fromIntegral (length channels) + + let maxNumberOfBatches = + (opts ^. gracePeriodSeconds * 1000) + `div` (opts ^. millisecondsBetweenBatches) + computedBatchSize = numberOfChannels `div` maxNumberOfBatches + batchSize = max (opts ^. minBatchSize) computedBatchSize + + logDraining + pool.logger + numberOfChannels + batchSize + (opts ^. minBatchSize) + computedBatchSize + maxNumberOfBatches + + -- Sleep for the grace period + 1 second. If the sleep completes, it means + -- that draining didn't finish, and we should log that. + withAsync + ( do + -- Allocate 1 second more than the grace period to allow for overhead of + -- spawning threads. + liftIO $ threadDelay $ ((opts ^. gracePeriodSeconds) # Second + 1 # Second) + logExpired pool.logger (opts ^. gracePeriodSeconds) + ) + $ \_ -> do + for_ (chunksOf (fromIntegral batchSize) channels) $ \batch -> do + -- 16 was chosen with a roll of a fair dice. + concurrently + (pooledMapConcurrentlyN_ 16 (closeChannel pool.logger) batch) + (liftIO $ threadDelay ((opts ^. millisecondsBetweenBatches) # MilliSecond)) + Log.info pool.logger $ Log.msg (Log.val "Draining complete") + where + closeChannel :: (ToByteString key) => Log.Logger -> (key, Q.Channel) -> IO () + closeChannel l (key, chan) = do + Log.info l $ + Log.msg (Log.val "closing rabbitmq channel") + . 
Log.field "key" (toByteString' key) + Q.closeChannel chan + + logExpired :: Log.Logger -> Word64 -> IO () + logExpired l period = do + Log.err l $ Log.msg (Log.val "Drain grace period expired") . Log.field "gracePeriodSeconds" period + + logDraining :: Log.Logger -> Word64 -> Word64 -> Word64 -> Word64 -> Word64 -> IO () + logDraining l count b minB batchSize m = do + Log.info l $ + Log.msg (Log.val "draining all rabbitmq channels") + . Log.field "numberOfChannels" count + . Log.field "computedBatchSize" b + . Log.field "minBatchSize" minB + . Log.field "batchSize" batchSize + . Log.field "maxNumberOfBatches" m + +createConnection :: (Ord key) => RabbitMqPool key -> IO (PooledConnection key) +createConnection pool = mask_ $ do + conn <- openConnection pool + mpconn <- runMaybeT . atomically $ do + -- do not create new connections when in draining mode + readTVar pool.draining >>= guard . not + connId <- readTVar pool.nextId + writeTVar pool.nextId $! succ connId + let c = + PooledConnection + { connId = connId, + channels = mempty, + inner = conn + } + modifyTVar pool.connections (c :) + pure c + pconn <- maybe (throwIO TooManyChannels) pure mpconn + + closedVar <- newEmptyMVar + -- Fire and forget: the thread will terminate by itself as soon as the + -- connection is closed (or if the pool is destroyed). + -- Asynchronous exception safety is guaranteed because exceptions are masked + -- in this whole block. + void . 
async $ do + v <- race (takeMVar closedVar) (readMVar pool.deadVar) + when (isRight v) $ + -- close connection and ignore exceptions + catch @SomeException (Q.closeConnection conn) $ + \_ -> pure () + atomically $ do + conns <- readTVar pool.connections + writeTVar pool.connections $ + filter (\c -> c.connId /= pconn.connId) conns + Q.addConnectionClosedHandler conn True $ do + putMVar closedVar () + pure pconn + +openConnection :: RabbitMqPool key -> IO Q.Connection +openConnection pool = do + (username, password) <- readCredsFromEnv + recovering + rabbitMqRetryPolicy + ( skipAsyncExceptions + <> [logRetries (const $ pure True) (logConnectionError pool.logger)] + ) + ( const $ do + Log.info pool.logger $ + Log.msg (Log.val "Trying to connect to RabbitMQ") + mTlsSettings <- + traverse + (liftIO . (mkTLSSettings pool.opts.endpoint.host)) + pool.opts.endpoint.tls + liftIO $ + Q.openConnection'' $ + Q.defaultConnectionOpts + { Q.coServers = + [ ( pool.opts.endpoint.host, + fromIntegral pool.opts.endpoint.port + ) + ], + Q.coVHost = pool.opts.endpoint.vHost, + Q.coAuth = [Q.plain username password], + Q.coTLSSettings = fmap Q.TLSCustom mTlsSettings + } + ) + +data RabbitMqChannel = RabbitMqChannel + { -- | The current channel. The var is empty while the channel is being + -- re-established. 
+ inner :: MVar Q.Channel, + msgVar :: MVar (Maybe (Q.Message, Q.Envelope)) + } + +getMessage :: RabbitMqChannel -> IO (Q.Message, Q.Envelope) +getMessage chan = takeMVar chan.msgVar >>= maybe (throwIO ChannelClosed) pure + +ackMessage :: RabbitMqChannel -> Word64 -> Bool -> IO () +ackMessage chan deliveryTag multiple = do + inner <- readMVar chan.inner + Q.ackMsg inner deliveryTag multiple + +createChannel :: (Ord key) => RabbitMqPool key -> Text -> key -> Codensity IO RabbitMqChannel +createChannel pool queue key = do + closedVar <- lift newEmptyMVar + inner <- lift newEmptyMVar + msgVar <- lift newEmptyMVar + + let handleException e = do + retry <- case (Q.isNormalChannelClose e, fromException e) of + (True, _) -> do + Log.info pool.logger $ + Log.msg (Log.val "RabbitMQ channel is closed normally, not attempting to reopen channel") + pure False + (_, Just (Q.ConnectionClosedException {})) -> do + Log.info pool.logger $ + Log.msg (Log.val "RabbitMQ connection was closed unexpectedly") + pure True + _ -> do + unless (fromException e == Just AsyncCancelled) $ + logException pool.logger "RabbitMQ channel closed" e + pure True + putMVar closedVar retry + + let manageChannel = do + retry <- lowerCodensity $ do + conn <- Codensity $ bracket (acquireConnection pool) (releaseConnection pool key) + chan <- Codensity $ bracket (Q.openChannel conn.inner) $ \c -> + catch (Q.closeChannel c) $ \(_ :: SomeException) -> pure () + connSize <- atomically $ do + let conn' = conn {channels = Map.insert key chan conn.channels} + conns <- readTVar pool.connections + writeTVar pool.connections $! 
+ map (\c -> if c.connId == conn'.connId then conn' else c) conns + pure $ Map.size conn'.channels + if connSize > pool.opts.maxChannels + then pure True + else do + liftIO $ Q.addChannelExceptionHandler chan handleException + putMVar inner chan + void $ liftIO $ Q.consumeMsgs chan queue Q.Ack $ \(message, envelope) -> do + putMVar msgVar (Just (message, envelope)) + takeMVar closedVar + + when retry manageChannel + + void $ + Codensity $ + withAsync $ + catch manageChannel handleException + `finally` putMVar msgVar Nothing + pure RabbitMqChannel {inner = inner, msgVar = msgVar} + +acquireConnection :: (Ord key) => RabbitMqPool key -> IO (PooledConnection key) +acquireConnection pool = do + findConnection pool >>= \case + Nothing -> do + bracketOnError + (createConnection pool) + (Q.closeConnection . (.inner)) + $ \conn -> do + -- if we have too many connections at this point, give up + numConnections <- atomically $ length <$> readTVar pool.connections + when (numConnections > pool.opts.maxConnections) $ + throw TooManyChannels + pure conn + Just conn -> pure conn + +findConnection :: RabbitMqPool key -> IO (Maybe (PooledConnection key)) +findConnection pool = (either throwIO pure <=< (atomically . runExceptT . runMaybeT)) $ do + conns <- lift . lift $ readTVar pool.connections + guard (notNull conns) + + let pconn = minimumOn (Map.size . 
(.channels)) $ conns + when (Map.size pconn.channels >= pool.opts.maxChannels) $ + if length conns >= pool.opts.maxConnections + then lift $ throwE TooManyChannels + else mzero + pure pconn + +releaseConnection :: (Ord key) => RabbitMqPool key -> key -> PooledConnection key -> IO () +releaseConnection pool key conn = atomically $ do + modifyTVar pool.connections $ map $ \c -> + if c.connId == conn.connId + then c {channels = Map.delete key c.channels} + else c + +logConnectionError :: Logger -> Bool -> SomeException -> RetryStatus -> IO () +logConnectionError l willRetry e retryStatus = do + Log.err l $ + Log.msg (Log.val "Failed to connect to RabbitMQ") + . Log.field "error" (displayException @SomeException e) + . Log.field "willRetry" willRetry + . Log.field "retryCount" retryStatus.rsIterNumber + +logException :: (MonadIO m) => Logger -> String -> SomeException -> m () +logException l m (SomeException e) = do + Log.err l $ + Log.msg m + . Log.field "error" (displayException e) + +rabbitMqRetryPolicy :: RetryPolicyM IO +rabbitMqRetryPolicy = limitRetriesByCumulativeDelay 1_000_000 $ fullJitterBackoff 1000 diff --git a/services/cannon/src/Cannon/RabbitMqConsumerApp.hs b/services/cannon/src/Cannon/RabbitMqConsumerApp.hs index 112b7ad8d2a..ede22279071 100644 --- a/services/cannon/src/Cannon/RabbitMqConsumerApp.hs +++ b/services/cannon/src/Cannon/RabbitMqConsumerApp.hs @@ -3,141 +3,68 @@ module Cannon.RabbitMqConsumerApp where import Cannon.App (rejectOnError) -import Cannon.Dict qualified as D -import Cannon.Options +import Cannon.RabbitMq import Cannon.WS hiding (env) import Cassandra as C hiding (batch) import Control.Concurrent.Async -import Control.Concurrent.Timeout import Control.Exception (Handler (..), bracket, catch, catches, throwIO, try) import Control.Lens hiding ((#)) import Control.Monad.Codensity import Data.Aeson hiding (Key) import Data.Id -import Data.List.Extra hiding (delete) -import Data.Timeout (TimeoutUnit (..), (#)) import Imports hiding 
(min, threadDelay) import Network.AMQP qualified as Q -import Network.AMQP.Extended (withConnection) import Network.WebSockets import Network.WebSockets qualified as WS import System.Logger qualified as Log -import UnliftIO.Async (pooledMapConcurrentlyN_) import Wire.API.Event.WebSocketProtocol import Wire.API.Notification -drainRabbitQueues :: Env -> IO () -drainRabbitQueues e = do - conns <- D.toList e.rabbitConnections - numberOfConns <- fromIntegral <$> D.size e.rabbitConnections - - let opts = e.drainOpts - maxNumberOfBatches = (opts ^. gracePeriodSeconds * 1000) `div` (opts ^. millisecondsBetweenBatches) - computedBatchSize = numberOfConns `div` maxNumberOfBatches - batchSize = max (opts ^. minBatchSize) computedBatchSize - - logDraining e.logg numberOfConns batchSize (opts ^. minBatchSize) computedBatchSize maxNumberOfBatches - - -- Sleeps for the grace period + 1 second. If the sleep completes, it means - -- that draining didn't finish, and we should log that. - timeoutAction <- async $ do - -- Allocate 1 second more than the grace period to allow for overhead of - -- spawning threads. - liftIO $ threadDelay $ ((opts ^. gracePeriodSeconds) # Second + 1 # Second) - logExpired e.logg (opts ^. gracePeriodSeconds) - - for_ (chunksOf (fromIntegral batchSize) conns) $ \batch -> do - -- 16 was chosen with a roll of a fair dice. - void . async $ pooledMapConcurrentlyN_ 16 (uncurry (closeConn e.logg)) batch - liftIO $ threadDelay ((opts ^. millisecondsBetweenBatches) # MilliSecond) - cancel timeoutAction - Log.info e.logg $ Log.msg (Log.val "Draining complete") - where - closeConn :: Log.Logger -> Key -> Q.Connection -> IO () - closeConn l key conn = do - Log.info l $ - Log.msg (Log.val "closing rabbitmq connection") - . Log.field "key" (show key) - Q.closeConnection conn - void $ D.remove key e.rabbitConnections - - logExpired :: Log.Logger -> Word64 -> IO () - logExpired l period = do - Log.err l $ Log.msg (Log.val "Drain grace period expired") . 
Log.field "gracePeriodSeconds" period - - logDraining :: Log.Logger -> Word64 -> Word64 -> Word64 -> Word64 -> Word64 -> IO () - logDraining l count b min batchSize m = do - Log.info l $ - Log.msg (Log.val "draining all rabbitmq connections") - . Log.field "numberOfConns" count - . Log.field "computedBatchSize" b - . Log.field "minBatchSize" min - . Log.field "batchSize" batchSize - . Log.field "maxNumberOfBatches" m - rabbitMQWebSocketApp :: UserId -> ClientId -> Env -> ServerApp rabbitMQWebSocketApp uid cid e pendingConn = do - wsVar <- newEmptyMVar - msgVar <- newEmptyMVar - - bracket (openWebSocket wsVar) closeWebSocket $ \(wsConn, _) -> + bracket openWebSocket closeWebSocket $ \wsConn -> ( do - sendFullSyncMessageIfNeeded wsVar wsConn uid cid e - sendNotifications wsConn msgVar wsVar + sendFullSyncMessageIfNeeded wsConn uid cid e + sendNotifications wsConn ) `catches` [ handleClientMisbehaving wsConn, - handleWebSocketExceptions wsConn + handleWebSocketExceptions wsConn, + handleOtherExceptions wsConn ] where logClient = Log.field "user" (idToText uid) . Log.field "client" (clientToText cid) - openWebSocket wsVar = do - wsConn <- - acceptRequest pendingConn - `catch` rejectOnError pendingConn - -- start a reader thread for client messages - -- this needs to run asynchronously in order to promptly react to - -- client-side connection termination - a <- async $ forever $ do - catch - ( do - msg <- getClientMessage wsConn - putMVar wsVar (Right msg) - ) - $ \err -> putMVar wsVar (Left err) - pure (wsConn, a) + openWebSocket = + acceptRequest pendingConn + `catch` rejectOnError pendingConn - -- this is only needed in case of asynchronous exceptions - closeWebSocket (wsConn, a) = do - cancel a + closeWebSocket wsConn = do logCloseWebsocket -- ignore any exceptions when sending the close message void . 
try @SomeException $ WS.sendClose wsConn ("" :: ByteString) - -- Create a rabbitmq consumer that receives messages and saves them into an MVar - createConsumer :: - Q.Channel -> - MVar (Either Q.AMQPException EventData) -> - IO Q.ConsumerTag - createConsumer chan msgVar = do - Q.consumeMsgs chan (clientNotificationQueueName uid cid) Q.Ack $ - \(msg, envelope) -> case eitherDecode @QueuedNotification msg.msgBody of - Left err -> do - logParseError err - -- This message cannot be parsed, make sure it doesn't requeue. There - -- is no need to throw an error and kill the websocket as this is - -- probably caused by a bug or someone messing with RabbitMQ. - -- - -- The bug case is slightly dangerous as it could drop a lot of events - -- en masse, if at some point we decide that Events should not be - -- pushed as JSONs, hopefully we think of the parsing side if/when - -- that happens. - Q.rejectEnv envelope False - Right notif -> - putMVar msgVar . Right $ - EventData notif envelope.envDeliveryTag + getEventData :: RabbitMqChannel -> IO EventData + getEventData chan = do + (msg, envelope) <- getMessage chan + case eitherDecode @QueuedNotification msg.msgBody of + Left err -> do + logParseError err + -- This message cannot be parsed, make sure it doesn't requeue. There + -- is no need to throw an error and kill the websocket as this is + -- probably caused by a bug or someone messing with RabbitMQ. + -- + -- The bug case is slightly dangerous as it could drop a lot of events + -- en masse, if at some point we decide that Events should not be + -- pushed as JSONs, hopefully we think of the parsing side if/when + -- that happens. + Q.rejectEnv envelope False + -- try again + getEventData chan + Right notif -> do + logEvent notif + pure $ EventData notif envelope.envDeliveryTag handleWebSocketExceptions wsConn = Handler $ @@ -173,54 +100,31 @@ rabbitMQWebSocketApp uid cid e pendingConn = do Log.msg (Log.val "Client sent unexpected ack message") . 
logClient WS.sendCloseCode wsConn 1003 ("unexpected-ack" :: ByteString) - sendNotifications :: - WS.Connection -> - MVar (Either Q.AMQPException EventData) -> - MVar (Either ConnectionException MessageClientToServer) -> - IO () - sendNotifications wsConn msgVar wsVar = lowerCodensity $ do - -- create rabbitmq connection - conn <- Codensity $ withConnection e.logg e.rabbitmq - - -- Store it in the env - let key = mkKeyRabbit uid cid - D.insert key conn e.rabbitConnections - - -- create rabbitmq channel - amqpChan <- Codensity $ bracket (Q.openChannel conn) Q.closeChannel - -- propagate rabbitmq connection failure - lift $ Q.addConnectionClosedHandler conn True $ do - void $ D.remove key e.rabbitConnections - putMVar msgVar $ - Left (Q.ConnectionClosedException Q.Normal "") + handleOtherExceptions wsConn = Handler $ + \(err :: SomeException) -> do + WS.sendCloseCode wsConn 1003 ("internal-error" :: ByteString) + throwIO err - -- register consumer that pushes rabbitmq messages into msgVar - void $ - Codensity $ - bracket - (createConsumer amqpChan msgVar) - (Q.cancelConsumer amqpChan) + sendNotifications :: WS.Connection -> IO () + sendNotifications wsConn = lowerCodensity $ do + let key = mkKeyRabbit uid cid + chan <- createChannel e.pool (clientNotificationQueueName uid cid) key - -- get data from msgVar and push to client let consumeRabbitMq = forever $ do - eventData' <- takeMVar msgVar - either throwIO pure eventData' >>= \eventData -> do - logEvent eventData.event - catch (WS.sendBinaryData wsConn (encode (EventMessage eventData))) $ - \(err :: SomeException) -> do - logSendFailure err - void $ D.remove key e.rabbitConnections - throwIO err + eventData <- getEventData chan + catch (WS.sendBinaryData wsConn (encode (EventMessage eventData))) $ + \(err :: SomeException) -> do + logSendFailure err + throwIO err - -- get ack from wsVar and forward to rabbitmq + -- get ack from websocket and forward to rabbitmq let consumeWebsocket = forever $ do - v <- takeMVar 
wsVar - either throwIO pure v >>= \case + getClientMessage wsConn >>= \case AckFullSync -> throwIO UnexpectedAck AckMessage ackData -> do logAckReceived ackData - void $ Q.ackMsg amqpChan ackData.deliveryTag ackData.multiple + void $ ackMessage chan ackData.deliveryTag ackData.multiple -- run both loops concurrently, so that -- - notifications are delivered without having to wait for acks @@ -265,16 +169,15 @@ rabbitMQWebSocketApp uid cid e pendingConn = do -- | Check if client has missed messages. If so, send a full synchronisation -- message and wait for the corresponding ack. sendFullSyncMessageIfNeeded :: - MVar (Either ConnectionException MessageClientToServer) -> WS.Connection -> UserId -> ClientId -> Env -> IO () -sendFullSyncMessageIfNeeded wsVar wsConn uid cid env = do +sendFullSyncMessageIfNeeded wsConn uid cid env = do row <- C.runClient env.cassandra do retry x5 $ query1 q (params LocalQuorum (uid, cid)) - for_ row $ \_ -> sendFullSyncMessage uid cid wsVar wsConn env + for_ row $ \_ -> sendFullSyncMessage uid cid wsConn env where q :: PrepQuery R (UserId, ClientId) (Identity (Maybe UserId)) q = @@ -285,15 +188,13 @@ sendFullSyncMessageIfNeeded wsVar wsConn uid cid env = do sendFullSyncMessage :: UserId -> ClientId -> - MVar (Either ConnectionException MessageClientToServer) -> WS.Connection -> Env -> IO () -sendFullSyncMessage uid cid wsVar wsConn env = do +sendFullSyncMessage uid cid wsConn env = do let event = encode EventFullSync WS.sendBinaryData wsConn event - res <- takeMVar wsVar >>= either throwIO pure - case res of + getClientMessage wsConn >>= \case AckMessage _ -> throwIO UnexpectedAck AckFullSync -> C.runClient env.cassandra do diff --git a/services/cannon/src/Cannon/Run.hs b/services/cannon/src/Cannon/Run.hs index 7596b6d2fab..eff9a612f81 100644 --- a/services/cannon/src/Cannon/Run.hs +++ b/services/cannon/src/Cannon/Run.hs @@ -27,8 +27,8 @@ import Cannon.API.Public import Cannon.App (maxPingInterval) import Cannon.Dict qualified as D 
import Cannon.Options -import Cannon.RabbitMqConsumerApp (drainRabbitQueues) -import Cannon.Types (Cannon, applog, clients, connectionLimit, env, mkEnv, runCannon, runCannonToServant) +import Cannon.RabbitMq +import Cannon.Types hiding (Env) import Cannon.WS hiding (drainOpts, env) import Cassandra.Util (defInitCassandra) import Control.Concurrent @@ -36,7 +36,8 @@ import Control.Concurrent.Async qualified as Async import Control.Exception qualified as E import Control.Exception.Safe (catchAny) import Control.Lens ((^.)) -import Control.Monad.Catch (MonadCatch, finally) +import Control.Monad.Catch (MonadCatch) +import Control.Monad.Codensity import Data.Metrics.Servant import Data.Proxy import Data.Text (pack, strip) @@ -67,26 +68,33 @@ import Wire.OpenTelemetry (withTracer) type CombinedAPI = CannonAPI :<|> Internal.API run :: Opts -> IO () -run o = withTracer \tracer -> do +run o = lowerCodensity $ do + lift $ validateOpts o + tracer <- Codensity withTracer when (o ^. drainOpts . millisecondsBetweenBatches == 0) $ error "drainOpts.millisecondsBetweenBatches must not be set to 0." when (o ^. drainOpts . gracePeriodSeconds == 0) $ error "drainOpts.gracePeriodSeconds must not be set to 0." - ext <- loadExternal - g <- L.mkLogger (o ^. logLevel) (o ^. logNetStrings) (o ^. logFormat) - cassandra <- defInitCassandra (o ^. cassandraOpts) g - e <- - mkEnv ext o cassandra g - <$> D.empty connectionLimit - <*> D.empty connectionLimit - <*> newManager defaultManagerSettings {managerConnCount = connectionLimit} - <*> createSystemRandom - <*> mkClock - <*> pure (o ^. Cannon.Options.rabbitmq) - refreshMetricsThread <- Async.async $ runCannon e refreshMetrics + ext <- lift loadExternal + g <- + Codensity $ + E.bracket + (L.mkLogger (o ^. logLevel) (o ^. logNetStrings) (o ^. logFormat)) + L.close + cassandra <- lift $ defInitCassandra (o ^. 
cassandraOpts) g + + e <- do + d1 <- D.empty numDictSlices + d2 <- D.empty numDictSlices + man <- lift $ newManager defaultManagerSettings {managerConnCount = 128} + rnd <- lift createSystemRandom + clk <- lift mkClock + mkEnv ext o cassandra g d1 d2 man rnd clk (o ^. Cannon.Options.rabbitmq) + + void $ Codensity $ Async.withAsync $ runCannon e refreshMetrics s <- newSettings $ Server (o ^. cannon . host) (o ^. cannon . port) (applog e) (Just idleTimeout) - otelMiddleWare <- newOpenTelemetryWaiMiddleware + otelMiddleWare <- lift newOpenTelemetryWaiMiddleware let middleware :: Wai.Middleware middleware = versionMiddleware (foldMap expandVersionExp (o ^. disabledAPIVersions)) @@ -101,17 +109,20 @@ run o = withTracer \tracer -> do server = hoistServer (Proxy @CannonAPI) (runCannonToServant e) publicAPIServer :<|> hoistServer (Proxy @Internal.API) (runCannonToServant e) internalServer - tid <- myThreadId - E.handle uncaughtExceptionHandler $ do - void $ installHandler sigTERM (signalHandler (env e) tid) Nothing - void $ installHandler sigINT (signalHandler (env e) tid) Nothing - inSpan tracer "cannon" defaultSpanArguments {kind = Otel.Server} (runSettings s app) `finally` do + tid <- lift myThreadId + + Codensity $ \k -> + inSpan tracer "cannon" defaultSpanArguments {kind = Otel.Server} (k ()) + lift $ + E.handle uncaughtExceptionHandler $ do + let handler = signalHandler (env e) (o ^. drainOpts) tid + void $ installHandler sigTERM handler Nothing + void $ installHandler sigINT handler Nothing -- FUTUREWORK(@akshaymankar, @fisx): we may want to call `runSettingsWithShutdown` here, -- but it's a sensitive change, and it looks like this is closing all the websockets at -- the same time and then calling the drain script. I suspect this might be due to some -- cleanup in wai. this needs to be tested very carefully when touched. 
- Async.cancel refreshMetricsThread - L.close (applog e) + runSettings s app where idleTimeout = fromIntegral $ maxPingInterval + 3 -- Each cannon instance advertises its own location (ip or dns name) to gundeck. @@ -123,10 +134,10 @@ run o = withTracer \tracer -> do readExternal :: FilePath -> IO ByteString readExternal f = encodeUtf8 . strip . pack <$> Strict.readFile f -signalHandler :: Env -> ThreadId -> Signals.Handler -signalHandler e mainThread = CatchOnce $ do +signalHandler :: Env -> DrainOpts -> ThreadId -> Signals.Handler +signalHandler e opts mainThread = CatchOnce $ do runWS e drain - drainRabbitQueues e + drainRabbitMqPool e.pool opts throwTo mainThread SignalledToExit -- | This is called when the main thread receives the exception generated by diff --git a/services/cannon/src/Cannon/Types.hs b/services/cannon/src/Cannon/Types.hs index ec6e24d729a..81773ed32a1 100644 --- a/services/cannon/src/Cannon/Types.hs +++ b/services/cannon/src/Cannon/Types.hs @@ -20,7 +20,7 @@ module Cannon.Types ( Env (..), Cannon, - connectionLimit, + numDictSlices, mapConcurrentlyCannon, mkEnv, runCannon, @@ -34,12 +34,14 @@ import Bilge (Manager) import Bilge.RPC (HasRequestId (..)) import Cannon.Dict (Dict) import Cannon.Options +import Cannon.RabbitMq import Cannon.WS (Clock, Key, Websocket) import Cannon.WS qualified as WS import Cassandra (ClientState) import Control.Concurrent.Async (mapConcurrently) import Control.Lens ((^.)) import Control.Monad.Catch +import Control.Monad.Codensity import Data.Id import Data.Text.Encoding import Imports @@ -51,8 +53,8 @@ import System.Logger qualified as Logger import System.Logger.Class hiding (info) import System.Random.MWC (GenIO) -connectionLimit :: Int -connectionLimit = 128 +numDictSlices :: Int +numDictSlices = 128 ----------------------------------------------------------------------------- -- Cannon monad @@ -106,10 +108,31 @@ mkEnv :: GenIO -> Clock -> AmqpEndpoint -> - Env -mkEnv external o cs l d conns p g t rabbitmqOpts 
= - Env o l d conns (RequestId defRequestId) $ - WS.env external (o ^. cannon . port) (encodeUtf8 $ o ^. gundeck . host) (o ^. gundeck . port) l p d conns g t (o ^. drainOpts) rabbitmqOpts cs + Codensity IO Env +mkEnv external o cs l d conns p g t endpoint = do + let poolOpts = + RabbitMqPoolOptions + { endpoint = endpoint, + maxConnections = o ^. rabbitMqMaxConnections, + maxChannels = o ^. rabbitMqMaxChannels + } + pool <- createRabbitMqPool poolOpts l + let wsEnv = + WS.env + external + (o ^. cannon . port) + (encodeUtf8 $ o ^. gundeck . host) + (o ^. gundeck . port) + l + p + d + conns + g + t + (o ^. drainOpts) + cs + pool + pure $ Env o l d conns (RequestId defRequestId) wsEnv runCannon :: Env -> Cannon a -> IO a runCannon e c = runReaderT (unCannon c) e diff --git a/services/cannon/src/Cannon/WS.hs b/services/cannon/src/Cannon/WS.hs index f868a033bb6..1653c82fbd9 100644 --- a/services/cannon/src/Cannon/WS.hs +++ b/services/cannon/src/Cannon/WS.hs @@ -53,6 +53,7 @@ import Bilge.Retry import Cannon.Dict (Dict) import Cannon.Dict qualified as D import Cannon.Options (DrainOpts, gracePeriodSeconds, millisecondsBetweenBatches, minBatchSize) +import Cannon.RabbitMq import Cassandra (ClientState) import Conduit import Control.Concurrent.Timeout @@ -60,6 +61,7 @@ import Control.Lens ((^.)) import Control.Monad.Catch import Control.Retry import Data.Aeson hiding (Error, Key) +import Data.Binary.Builder qualified as B import Data.ByteString.Char8 (pack) import Data.ByteString.Conversion import Data.ByteString.Lazy qualified as L @@ -70,7 +72,6 @@ import Data.Text.Encoding (decodeUtf8) import Data.Timeout (TimeoutUnit (..), (#)) import Imports hiding (threadDelay) import Network.AMQP qualified as Q -import Network.AMQP.Extended import Network.HTTP.Types.Method import Network.HTTP.Types.Status import Network.Wai.Utilities.Error @@ -87,7 +88,7 @@ import Wire.API.Presence newtype Key = Key { _key :: (ByteString, ByteString) } - deriving (Eq, Show, Hashable) + deriving 
(Eq, Show, Hashable, Ord) mkKey :: UserId -> ConnId -> Key mkKey u c = Key (toByteString' u, fromConnId c) @@ -95,6 +96,9 @@ mkKey u c = Key (toByteString' u, fromConnId c) mkKeyRabbit :: UserId -> ClientId -> Key mkKeyRabbit u c = Key (toByteString' u, toByteString' c) +instance ToByteString Key where + builder = B.fromByteString . key2bytes + key2bytes :: Key -> ByteString key2bytes (Key (u, c)) = u <> "." <> c @@ -154,8 +158,8 @@ data Env = Env rand :: !GenIO, clock :: !Clock, drainOpts :: DrainOpts, - rabbitmq :: !AmqpEndpoint, - cassandra :: ClientState + cassandra :: ClientState, + pool :: RabbitMqPool Key } setRequestId :: RequestId -> Env -> Env @@ -202,8 +206,8 @@ env :: GenIO -> Clock -> DrainOpts -> - AmqpEndpoint -> ClientState -> + RabbitMqPool Key -> Env env leh lp gh gp = Env leh lp (Bilge.host gh . Bilge.port gp $ empty) (RequestId defRequestId)