Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 37 additions & 11 deletions src/Discord.Net.WebSocket/Audio/DaveSessionManager.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Discord.API.Voice;
using Discord.API.Voice;
using Discord.LibDave;
using Discord.LibDave.Binding;
using Discord.Logging;
Expand Down Expand Up @@ -175,11 +175,27 @@ await _logger.DebugAsync(

await _logger.DebugAsync($"Processed dave MLS proposals, has data?: {result.HasData}");

if (result.HasData) await SendMLSCommitWelcomeAsync(result.ToMemory());
if (result.HasData)
await SendMLSCommitWelcomeAsync(result.ToMemory());
}

private async ValueTask OnDaveMLSAnnounceCommitTransactionAsync(ushort transitionId, ReadOnlyMemory<byte> payload)
{
if (payload.IsEmpty)
{
await _logger.DebugAsync(
$"Payload for commit was empty; transaction id: {transitionId}, getting new key package..."
);

await SendMLSInvalidCommitWelcomeAsync(transitionId);
using var keyPackage = _session.GetMarshalledKeyPackage();
await SendMLSKeyPackageAsync(keyPackage.ToMemory());

await HandleDaveProtocolInitAsync(transitionId);

return;
}

using var commit = _session.ProcessCommit(payload);

if (commit.IsIgnored)
Expand All @@ -188,6 +204,7 @@ private async ValueTask OnDaveMLSAnnounceCommitTransactionAsync(ushort transitio
await _logger.DebugAsync(
$"Commit result ignored, transaction id: {transitionId}, was prepared and removed? {wasRemoved}"
);
UpdateEncryptorRatchet(_session.ProtocolVersion);
return;
}

Expand Down Expand Up @@ -220,20 +237,16 @@ await _logger.DebugAsync(

foreach (var (id, decryptor) in _decryptors)
{
if (id == SelfUserId) continue;
if (id == SelfUserId)
continue;

decryptor.PrepareTransition(_session, id, protocolVersion);
}

UpdateEncryptorRatchet(protocolVersion);

if (transitionId is Dave.InitTransitionId)
{
var ratchet = _session.GetKeyRatchet(SelfUserId);

Encryptor.IsInPassthroughMode = protocolVersion is Dave.DisabledProtocolVersion || ratchet.IsNull;

if (protocolVersion is not Dave.DisabledProtocolVersion && !ratchet.IsNull)
Encryptor.Ratchet = ratchet;

// Streams created before DAVE was initialized lack the DaveDecryptStream layer.
// Rebuild them now that keys are ready.
await _client.RebuildInputStreamsForDaveAsync();
Expand Down Expand Up @@ -262,7 +275,8 @@ public async Task HandleDaveProtocolInitAsync(ushort protocolVersion)

public async Task HandlePrepareEpochAsync(ulong epoch, ushort protocolVersion)
{
if (epoch is not Dave.MLSNewGroupExpectedEpoch) return;
if (epoch is not Dave.MLSNewGroupExpectedEpoch)
return;

await _logger.DebugAsync($"Initializing dave session: epoch {epoch}, protocol version {protocolVersion}");

Expand Down Expand Up @@ -321,6 +335,18 @@ private Task SendDaveProtocolReadyForTransitionAsync(ushort transitionId)
new DaveMLSTransitionParams() { TransitionId = transitionId }
);

private void UpdateEncryptorRatchet(ushort protocolVersion)
{
var ratchet = _session.GetKeyRatchet(SelfUserId);

var isDisabled = protocolVersion is Dave.DisabledProtocolVersion || ratchet.IsNull;

Encryptor.IsInPassthroughMode = isDisabled;

if (!isDisabled)
Encryptor.Ratchet = ratchet;
}

public void Dispose()
{
foreach (var (_, decryptor) in _decryptors)
Expand Down
Loading