diff --git a/src/Discord.Net.WebSocket/Audio/DaveSessionManager.cs b/src/Discord.Net.WebSocket/Audio/DaveSessionManager.cs index 945100901e..baf74f31ab 100644 --- a/src/Discord.Net.WebSocket/Audio/DaveSessionManager.cs +++ b/src/Discord.Net.WebSocket/Audio/DaveSessionManager.cs @@ -1,4 +1,4 @@ -using Discord.API.Voice; +using Discord.API.Voice; using Discord.LibDave; using Discord.LibDave.Binding; using Discord.Logging; @@ -175,11 +175,27 @@ await _logger.DebugAsync( await _logger.DebugAsync($"Processed dave MLS proposals, has data?: {result.HasData}"); - if (result.HasData) await SendMLSCommitWelcomeAsync(result.ToMemory()); + if (result.HasData) + await SendMLSCommitWelcomeAsync(result.ToMemory()); } private async ValueTask OnDaveMLSAnnounceCommitTransactionAsync(ushort transitionId, ReadOnlyMemory payload) { + if (payload.IsEmpty) + { + await _logger.DebugAsync( + $"Payload for commit was empty; transaction id: {transitionId}, getting new key package..." + ); + + await SendMLSInvalidCommitWelcomeAsync(transitionId); + using var keyPackage = _session.GetMarshalledKeyPackage(); + await SendMLSKeyPackageAsync(keyPackage.ToMemory()); + + await HandleDaveProtocolInitAsync(transitionId); + + return; + } + using var commit = _session.ProcessCommit(payload); if (commit.IsIgnored) @@ -188,6 +204,7 @@ private async ValueTask OnDaveMLSAnnounceCommitTransactionAsync(ushort transitio await _logger.DebugAsync( $"Commit result ignored, transaction id: {transitionId}, was prepared and removed? {wasRemoved}" ); + UpdateEncryptorRatchet(_session.ProtocolVersion); return; } @@ -220,20 +237,16 @@ await _logger.DebugAsync( foreach (var (id, decryptor) in _decryptors) { - if (id == SelfUserId) continue; + if (id == SelfUserId) + continue; decryptor.PrepareTransition(_session, id, protocolVersion); } + UpdateEncryptorRatchet(protocolVersion); + if (transitionId is Dave.InitTransitionId) { - var ratchet = _session.GetKeyRatchet(SelfUserId); - - Encryptor.IsInPassthroughMode = protocolVersion is Dave.DisabledProtocolVersion || ratchet.IsNull; - - if (protocolVersion is not Dave.DisabledProtocolVersion && !ratchet.IsNull) - Encryptor.Ratchet = ratchet; - // Streams created before DAVE was initialized lack the DaveDecryptStream layer. // Rebuild them now that keys are ready. await _client.RebuildInputStreamsForDaveAsync(); @@ -262,7 +275,8 @@ public async Task HandleDaveProtocolInitAsync(ushort protocolVersion) public async Task HandlePrepareEpochAsync(ulong epoch, ushort protocolVersion) { - if (epoch is not Dave.MLSNewGroupExpectedEpoch) return; + if (epoch is not Dave.MLSNewGroupExpectedEpoch) + return; await _logger.DebugAsync($"Initializing dave session: epoch {epoch}, protocol version {protocolVersion}"); @@ -321,6 +335,18 @@ private Task SendDaveProtocolReadyForTransitionAsync(ushort transitionId) new DaveMLSTransitionParams() { TransitionId = transitionId } ); + private void UpdateEncryptorRatchet(ushort protocolVersion) + { + var ratchet = _session.GetKeyRatchet(SelfUserId); + + var isDisabled = protocolVersion is Dave.DisabledProtocolVersion || ratchet.IsNull; + + Encryptor.IsInPassthroughMode = isDisabled; + + if (!isDisabled) + Encryptor.Ratchet = ratchet; + } + public void Dispose() { foreach (var (_, decryptor) in _decryptors)