Skip to content

Commit

Permalink
Disconnect bot on voice connection exceptions and run cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
Quahu committed Aug 3, 2024
1 parent f34d4ef commit 37eeab1
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 23 deletions.
14 changes: 14 additions & 0 deletions examples/Voice/BasicVoice/AudioModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,18 @@ public async Task<IResult> Skip()

return Response(new LocalInteractionMessageResponse().WithContent("Skipped.").WithIsEphemeral());
}

[SlashCommand("stop")]
[Description("Stops the playback and disconnects the bot from the voice channel.")]
public async Task<IResult> Stop()
{
var player = await _playerService.GetPlayerAsync(Context.GuildId);
if (player == null)
{
return Response("Not playing.");
}

await _playerService.DisposePlayerAsync(Context.GuildId);
return Response("Disconnected.");
}
}
65 changes: 42 additions & 23 deletions src/Disqord.Extensions.Voice/VoiceExtension.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System;
using System.Collections.Generic;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -17,6 +16,7 @@ public class VoiceExtension : DiscordClientExtension
{
private readonly IVoiceConnectionFactory _connectionFactory;

private readonly IThreadSafeDictionary<Snowflake, IVoiceConnection> _pendingConnections;
private readonly IThreadSafeDictionary<Snowflake, VoiceConnectionInfo> _connections;

public VoiceExtension(
Expand All @@ -26,6 +26,7 @@ public VoiceExtension(
{
_connectionFactory = connectionFactory;

_pendingConnections = ThreadSafeDictionary.Monitor.Create<Snowflake, IVoiceConnection>();
_connections = ThreadSafeDictionary.Monitor.Create<Snowflake, VoiceConnectionInfo>();
}

Expand All @@ -40,7 +41,7 @@ protected override ValueTask InitializeAsync(CancellationToken cancellationToken

private Task VoiceServerUpdatedAsync(object? sender, VoiceServerUpdatedEventArgs e)
{
GetConnection(e.GuildId)?.OnVoiceServerUpdate(e.Token, e.Endpoint);
_pendingConnections.GetValueOrDefault(e.GuildId)?.OnVoiceServerUpdate(e.Token, e.Endpoint);
return Task.CompletedTask;
}

Expand All @@ -52,7 +53,7 @@ private Task VoiceStateUpdatedAsync(object? sender, VoiceStateUpdatedEventArgs e
}

var voiceState = e.NewVoiceState;
GetConnection(e.GuildId)?.OnVoiceStateUpdate(voiceState.ChannelId, voiceState.SessionId);
_pendingConnections.GetValueOrDefault(e.GuildId)?.OnVoiceStateUpdate(voiceState.ChannelId, voiceState.SessionId);
return Task.CompletedTask;
}

Expand Down Expand Up @@ -106,20 +107,28 @@ public async ValueTask<IVoiceConnection> ConnectAsync(Snowflake guildId, Snowfla
return new(shard.SetVoiceStateAsync(guildId, channelId, false, true, cancellationToken));
});

var connectionInfo = new VoiceConnectionInfo(connection, Cts.Linked(Client.StoppingToken));
_connections[guildId] = connectionInfo;
try
_pendingConnections[guildId] = connection;
using (var linkedReadyCts = Cts.Linked(cancellationToken, Client.StoppingToken))
{
var readyTask = connection.WaitUntilReadyAsync(cancellationToken);
_ = connection.RunAsync(connectionInfo.Cts.Token);
var readyTask = connection.WaitUntilReadyAsync(linkedReadyCts.Token);

await readyTask.ConfigureAwait(false);
}
catch
{
_connections.Remove(guildId);
await connectionInfo.DisposeAsync();
throw;
var linkedRunCts = Cts.Linked(Client.StoppingToken);
Task runTask;
try
{
runTask = connection.RunAsync(linkedRunCts.Token);
await readyTask.ConfigureAwait(false);
}
catch
{
_pendingConnections.Remove(guildId);
linkedRunCts.Cancel();
linkedRunCts.Dispose();

throw;
}

_connections[guildId] = new VoiceConnectionInfo(connection, runTask, linkedRunCts);
}

return connection;
Expand All @@ -133,33 +142,43 @@ public async ValueTask<IVoiceConnection> ConnectAsync(Snowflake guildId, Snowfla
/// Use <see cref="ConnectAsync"/> to obtain a new connection afterward.
/// </remarks>
/// <param name="guildId"> The ID of the guild. </param>
public ValueTask DisconnectAsync(Snowflake guildId)
public async ValueTask DisconnectAsync(Snowflake guildId)
{
if (!_connections.TryRemove(guildId, out var connectionInfo))
{
return default;
return;
}

return connectionInfo.DisposeAsync();
await connectionInfo.StopAsync().ConfigureAwait(false);
}

private readonly struct VoiceConnectionInfo : IAsyncDisposable
private readonly struct VoiceConnectionInfo
{
public IVoiceConnection Connection { get; }

public Task RunTask { get; }

public Cts Cts { get; }

public VoiceConnectionInfo(IVoiceConnection connection, Cts cts)
public VoiceConnectionInfo(IVoiceConnection connection, Task runTask, Cts cts)
{
Connection = connection;
RunTask = runTask;
Cts = cts;
}

public async ValueTask DisposeAsync()
public async ValueTask StopAsync()
{
Cts.Cancel();

try
{
await RunTask.ConfigureAwait(false);
}
catch { }

Cts.Dispose();
await Connection.DisposeAsync();
await Connection.DisposeAsync().ConfigureAwait(false);
}
}
}
3 changes: 3 additions & 0 deletions src/Disqord.Voice/Default/DefaultVoiceConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ await Gateway.SendAsync(new VoiceGatewayPayloadJsonModel
}
catch (OperationCanceledException ex) when (ex.CancellationToken == linkedCancellationToken && stoppingToken.IsCancellationRequested)
{
await _setVoiceStateDelegate(GuildId, null, default).ConfigureAwait(false);
_readyTcs.Cancel(ex.CancellationToken);
return;
}
Expand Down Expand Up @@ -389,6 +390,8 @@ await Gateway.SendAsync(new VoiceGatewayPayloadJsonModel
}
catch (Exception ex)
{
await _setVoiceStateDelegate(GuildId, null, default).ConfigureAwait(false);

lock (_readyTcs)
{
_readyTcs.Throw(ex);
Expand Down

0 comments on commit 37eeab1

Please sign in to comment.