Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 34 additions & 69 deletions src/Discord.Net.Commands/CommandService.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using AsyncKeyedLock;
using Discord.Commands.Builders;
using Discord.Logging;
using System;
Expand All @@ -6,7 +7,6 @@
using System.Collections.Immutable;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;

namespace Discord.Commands
Expand Down Expand Up @@ -46,7 +46,7 @@ public class CommandService : IDisposable
public event Func<Optional<CommandInfo>, ICommandContext, IResult, Task> CommandExecuted { add { _commandExecutedEvent.Add(value); } remove { _commandExecutedEvent.Remove(value); } }
internal readonly AsyncEvent<Func<Optional<CommandInfo>, ICommandContext, IResult, Task>> _commandExecutedEvent = new AsyncEvent<Func<Optional<CommandInfo>, ICommandContext, IResult, Task>>();

private readonly SemaphoreSlim _moduleLock;
private readonly AsyncNonKeyedLocker _moduleLock;
private readonly ConcurrentDictionary<Type, ModuleInfo> _typedModuleDefs;
private readonly ConcurrentDictionary<Type, ConcurrentDictionary<Type, TypeReader>> _typeReaders;
private readonly ConcurrentDictionary<Type, TypeReader> _defaultTypeReaders;
Expand Down Expand Up @@ -105,7 +105,7 @@ public CommandService(CommandServiceConfig config)
_logManager.Message += async msg => await _logEvent.InvokeAsync(msg).ConfigureAwait(false);
_cmdLogger = _logManager.CreateLogger("Command");

_moduleLock = new SemaphoreSlim(1, 1);
_moduleLock = new();
_typedModuleDefs = new ConcurrentDictionary<Type, ModuleInfo>();
_moduleDefs = new HashSet<ModuleInfo>();
_map = new CommandMap(this);
Expand Down Expand Up @@ -137,20 +137,13 @@ public CommandService(CommandServiceConfig config)
#region Modules
public async Task<ModuleInfo> CreateModuleAsync(string primaryAlias, Action<ModuleBuilder> buildFunc)
{
await _moduleLock.WaitAsync().ConfigureAwait(false);
try
{
var builder = new ModuleBuilder(this, null, primaryAlias);
buildFunc(builder);
using var _ = await _moduleLock.LockAsync().ConfigureAwait(false);
var builder = new ModuleBuilder(this, null, primaryAlias);
buildFunc(builder);

var module = builder.Build(this, null);
var module = builder.Build(this, null);

return LoadModuleInternal(module);
}
finally
{
_moduleLock.Release();
}
return LoadModuleInternal(module);
}

/// <summary>
Expand Down Expand Up @@ -191,27 +184,20 @@ public async Task<ModuleInfo> AddModuleAsync(Type type, IServiceProvider service
{
services ??= EmptyServiceProvider.Instance;

await _moduleLock.WaitAsync().ConfigureAwait(false);
try
{
var typeInfo = type.GetTypeInfo();
using var _ = await _moduleLock.LockAsync().ConfigureAwait(false);
var typeInfo = type.GetTypeInfo();

if (_typedModuleDefs.ContainsKey(type))
throw new ArgumentException("This module has already been added.");
if (_typedModuleDefs.ContainsKey(type))
throw new ArgumentException("This module has already been added.");

var module = (await ModuleClassBuilder.BuildAsync(this, services, typeInfo).ConfigureAwait(false)).FirstOrDefault();
var module = (await ModuleClassBuilder.BuildAsync(this, services, typeInfo).ConfigureAwait(false)).FirstOrDefault();

if (module.Value == default(ModuleInfo))
throw new InvalidOperationException($"Could not build the module {type.FullName}, did you pass an invalid type?");
if (module.Value == default(ModuleInfo))
throw new InvalidOperationException($"Could not build the module {type.FullName}, did you pass an invalid type?");

_typedModuleDefs[module.Key] = module.Value;
_typedModuleDefs[module.Key] = module.Value;

return LoadModuleInternal(module.Value);
}
finally
{
_moduleLock.Release();
}
return LoadModuleInternal(module.Value);
}
/// <summary>
/// Add command modules from an <see cref="Assembly"/>.
Expand All @@ -226,24 +212,17 @@ public async Task<IEnumerable<ModuleInfo>> AddModulesAsync(Assembly assembly, IS
{
services ??= EmptyServiceProvider.Instance;

await _moduleLock.WaitAsync().ConfigureAwait(false);
try
{
var types = await ModuleClassBuilder.SearchAsync(assembly, this).ConfigureAwait(false);
var moduleDefs = await ModuleClassBuilder.BuildAsync(types, this, services).ConfigureAwait(false);

foreach (var info in moduleDefs)
{
_typedModuleDefs[info.Key] = info.Value;
LoadModuleInternal(info.Value);
}
using var _ = await _moduleLock.LockAsync().ConfigureAwait(false);
var types = await ModuleClassBuilder.SearchAsync(assembly, this).ConfigureAwait(false);
var moduleDefs = await ModuleClassBuilder.BuildAsync(types, this, services).ConfigureAwait(false);

return moduleDefs.Select(x => x.Value).ToImmutableArray();
}
finally
foreach (var info in moduleDefs)
{
_moduleLock.Release();
_typedModuleDefs[info.Key] = info.Value;
LoadModuleInternal(info.Value);
}

return moduleDefs.Select(x => x.Value).ToImmutableArray();
}
private ModuleInfo LoadModuleInternal(ModuleInfo module)
{
Expand All @@ -267,20 +246,13 @@ private ModuleInfo LoadModuleInternal(ModuleInfo module)
/// </returns>
public async Task<bool> RemoveModuleAsync(ModuleInfo module)
{
await _moduleLock.WaitAsync().ConfigureAwait(false);
try
{
var typeModulePair = _typedModuleDefs.FirstOrDefault(x => x.Value.Equals(module));
using var _ = await _moduleLock.LockAsync().ConfigureAwait(false);
var typeModulePair = _typedModuleDefs.FirstOrDefault(x => x.Value.Equals(module));

if (!typeModulePair.Equals(default(KeyValuePair<Type, ModuleInfo>)))
_typedModuleDefs.TryRemove(typeModulePair.Key, out var _);
if (!typeModulePair.Equals(default(KeyValuePair<Type, ModuleInfo>)))
_typedModuleDefs.TryRemove(typeModulePair.Key, out var _);

return RemoveModuleInternal(module);
}
finally
{
_moduleLock.Release();
}
return RemoveModuleInternal(module);
}
/// <summary>
/// Removes the command module.
Expand All @@ -301,18 +273,11 @@ public async Task<bool> RemoveModuleAsync(ModuleInfo module)
/// </returns>
public async Task<bool> RemoveModuleAsync(Type type)
{
await _moduleLock.WaitAsync().ConfigureAwait(false);
try
{
if (!_typedModuleDefs.TryRemove(type, out var module))
return false;
using var _ = await _moduleLock.LockAsync().ConfigureAwait(false);
if (!_typedModuleDefs.TryRemove(type, out var module))
return false;

return RemoveModuleInternal(module);
}
finally
{
_moduleLock.Release();
}
return RemoveModuleInternal(module);
}
private bool RemoveModuleInternal(ModuleInfo module)
{
Expand Down
1 change: 1 addition & 0 deletions src/Discord.Net.Core/Discord.Net.Core.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
<SymbolPackageFormat>snupkg</SymbolPackageFormat>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="AsyncKeyedLock" Version="8.0.2" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.4" />
<PackageReference Include="System.Linq.AsyncEnumerable" Version="10.0.0" />
<PackageReference Include="IDisposableAnalyzers" Version="4.0.8">
Expand Down
112 changes: 37 additions & 75 deletions src/Discord.Net.Interactions/InteractionService.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using AsyncKeyedLock;
using Discord.Interactions.Builders;
using Discord.Logging;
using Discord.Rest;
Expand Down Expand Up @@ -100,7 +101,7 @@ public event Func<ICommandInfo, IInteractionContext, IResult, Task> InteractionE
private readonly TypeMap<ModalComponentTypeConverter, IComponentInteractionData> _modalInputTypeConverterMap;
private readonly ConcurrentDictionary<Type, IAutocompleteHandler> _autocompleteHandlers = new();
private readonly ConcurrentDictionary<Type, ModalInfo> _modalInfos = new();
private readonly SemaphoreSlim _lock;
private readonly AsyncNonKeyedLocker _lock;
internal readonly Logger _cmdLogger;
internal readonly LogManager _logManager;
internal readonly Func<DiscordRestClient> _getRestClient;
Expand Down Expand Up @@ -165,7 +166,7 @@ private InteractionService(Func<DiscordRestClient> getRestClient, InteractionSer
{
config ??= new InteractionServiceConfig();

_lock = new SemaphoreSlim(1, 1);
_lock = new();
_typedModuleDefs = new ConcurrentDictionary<Type, ModuleInfo>();
_moduleDefs = new HashSet<ModuleInfo>();

Expand Down Expand Up @@ -258,21 +259,14 @@ public async Task<ModuleInfo> CreateModuleAsync(string name, IServiceProvider se
{
services ??= EmptyServiceProvider.Instance;

await _lock.WaitAsync().ConfigureAwait(false);
try
{
var builder = new ModuleBuilder(this, name);
buildFunc(builder);
using var _ = await _lock.LockAsync().ConfigureAwait(false);
var builder = new ModuleBuilder(this, name);
buildFunc(builder);

var moduleInfo = builder.Build(this, services);
LoadModuleInternal(moduleInfo);
var moduleInfo = builder.Build(this, services);
LoadModuleInternal(moduleInfo);

return moduleInfo;
}
finally
{
_lock.Release();
}
return moduleInfo;
}

/// <summary>
Expand All @@ -287,24 +281,16 @@ public async Task<IEnumerable<ModuleInfo>> AddModulesAsync(Assembly assembly, IS
{
services ??= EmptyServiceProvider.Instance;

await _lock.WaitAsync().ConfigureAwait(false);
using var _ = await _lock.LockAsync().ConfigureAwait(false);
var types = await ModuleClassBuilder.SearchAsync(assembly, this);
var moduleDefs = await ModuleClassBuilder.BuildAsync(types, this, services);

try
foreach (var info in moduleDefs)
{
var types = await ModuleClassBuilder.SearchAsync(assembly, this);
var moduleDefs = await ModuleClassBuilder.BuildAsync(types, this, services);

foreach (var info in moduleDefs)
{
_typedModuleDefs[info.Key] = info.Value;
LoadModuleInternal(info.Value);
}
return moduleDefs.Values;
}
finally
{
_lock.Release();
_typedModuleDefs[info.Key] = info.Value;
LoadModuleInternal(info.Value);
}
return moduleDefs.Values;
}

/// <summary>
Expand Down Expand Up @@ -345,32 +331,24 @@ public async Task<ModuleInfo> AddModuleAsync(Type type, IServiceProvider service

services ??= EmptyServiceProvider.Instance;

await _lock.WaitAsync().ConfigureAwait(false);

try
{
var typeInfo = type.GetTypeInfo();
using var _ = await _lock.LockAsync().ConfigureAwait(false);
var typeInfo = type.GetTypeInfo();

if (_typedModuleDefs.ContainsKey(typeInfo))
throw new ArgumentException("Module definition for this type already exists.");
if (_typedModuleDefs.ContainsKey(typeInfo))
throw new ArgumentException("Module definition for this type already exists.");

var moduleDef = (await ModuleClassBuilder.BuildAsync(new List<TypeInfo> { typeInfo }, this, services).ConfigureAwait(false)).FirstOrDefault();
var moduleDef = (await ModuleClassBuilder.BuildAsync(new List<TypeInfo> { typeInfo }, this, services).ConfigureAwait(false)).FirstOrDefault();

if (moduleDef.Value == default)
throw new InvalidOperationException($"Could not build the module {typeInfo.FullName}, did you pass an invalid type?");
if (moduleDef.Value == default)
throw new InvalidOperationException($"Could not build the module {typeInfo.FullName}, did you pass an invalid type?");

if (!_typedModuleDefs.TryAdd(type, moduleDef.Value))
throw new ArgumentException("Module definition for this type already exists.");
if (!_typedModuleDefs.TryAdd(type, moduleDef.Value))
throw new ArgumentException("Module definition for this type already exists.");

_typedModuleDefs[moduleDef.Key] = moduleDef.Value;
LoadModuleInternal(moduleDef.Value);
_typedModuleDefs[moduleDef.Key] = moduleDef.Value;
LoadModuleInternal(moduleDef.Value);

return moduleDef.Value;
}
finally
{
_lock.Release();
}
return moduleDef.Value;
}

/// <summary>
Expand Down Expand Up @@ -641,19 +619,11 @@ public Task<bool> RemoveModuleAsync<T>() =>
/// </returns>
public async Task<bool> RemoveModuleAsync(Type type)
{
await _lock.WaitAsync().ConfigureAwait(false);

try
{
if (!_typedModuleDefs.TryRemove(type, out var module))
return false;
using var _ = await _lock.LockAsync().ConfigureAwait(false);
if (!_typedModuleDefs.TryRemove(type, out var module))
return false;

return RemoveModuleInternal(module);
}
finally
{
_lock.Release();
}
return RemoveModuleInternal(module);
}

/// <summary>
Expand All @@ -666,21 +636,13 @@ public async Task<bool> RemoveModuleAsync(Type type)
/// </returns>
public async Task<bool> RemoveModuleAsync(ModuleInfo module)
{
await _lock.WaitAsync().ConfigureAwait(false);

try
{
var typeModulePair = _typedModuleDefs.FirstOrDefault(x => x.Value.Equals(module));
using var _ = await _lock.LockAsync().ConfigureAwait(false);
var typeModulePair = _typedModuleDefs.FirstOrDefault(x => x.Value.Equals(module));

if (!typeModulePair.Equals(default(KeyValuePair<Type, ModuleInfo>)))
_typedModuleDefs.TryRemove(typeModulePair.Key, out var _);
if (!typeModulePair.Equals(default(KeyValuePair<Type, ModuleInfo>)))
_typedModuleDefs.TryRemove(typeModulePair.Key, out var _);

return RemoveModuleInternal(module);
}
finally
{
_lock.Release();
}
return RemoveModuleInternal(module);
}

/// <summary>
Expand Down
Loading