diff --git a/PowerKit.Tests/ResizableSemaphoreTests.cs b/PowerKit.Tests/ResizableSemaphoreTests.cs new file mode 100644 index 0000000..3e9bbb2 --- /dev/null +++ b/PowerKit.Tests/ResizableSemaphoreTests.cs @@ -0,0 +1,54 @@ +using System; +using System.Threading; +using System.Threading.Tasks; +using FluentAssertions; +using PowerKit; +using Xunit; + +namespace PowerKit.Tests; + +public class ResizableSemaphoreTests +{ + [Fact] + public async Task AcquireAsync_Test() + { + // Arrange + using var semaphore = new ResizableSemaphore { MaxCount = 1 }; + + // Act + using var access1 = await semaphore.AcquireAsync(); + var access2Task = semaphore.AcquireAsync(); + + // Assert + access2Task.IsCompleted.Should().BeFalse(); + access1.Dispose(); + using var access2 = await access2Task; + } + + [Fact] + public async Task AcquireAsync_Cancellation_Test() + { + // Arrange + using var semaphore = new ResizableSemaphore { MaxCount = 1 }; + using var _ = await semaphore.AcquireAsync(); + + // Act & assert + var act = async () => await semaphore.AcquireAsync(new CancellationToken(true)); + await act.Should().ThrowAsync(); + } + + [Fact] + public async Task AcquireAsync_Resized_Test() + { + // Arrange + using var semaphore = new ResizableSemaphore { MaxCount = 1 }; + using var _ = await semaphore.AcquireAsync(); + + // Act + var accessTask = semaphore.AcquireAsync(); + semaphore.MaxCount = 2; + + // Assert + using var access = await accessTask; + } +} diff --git a/PowerKit/ResizableSemaphore.cs b/PowerKit/ResizableSemaphore.cs new file mode 100644 index 0000000..6931066 --- /dev/null +++ b/PowerKit/ResizableSemaphore.cs @@ -0,0 +1,102 @@ +#if NET40_OR_GREATER || NETSTANDARD || NET +#nullable enable +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Threading; +using System.Threading.Tasks; + +namespace PowerKit; + +/// +/// Semaphore whose maximum concurrency count can be adjusted at run time. +/// +#if !POWERKIT_INCLUDE_COVERAGE +[ExcludeFromCodeCoverage] +#endif +internal class ResizableSemaphore : IDisposable +{ + private readonly Lock _lock = new(); + private readonly Queue _waiters = new(); + private readonly CancellationTokenSource _cts = new(); + + private bool _isDisposed; + private int _count; + + /// + /// Gets or sets the maximum number of concurrent accesses. + /// Defaults to . + /// + public int MaxCount + { + get => field; + set + { + using (_lock.EnterScope()) + field = value; + + Refresh(); + } + } = int.MaxValue; + + private void Refresh() + { + using (_lock.EnterScope()) + { + // Provide access to pending waiters, as long as max count allows + while (_count < MaxCount && _waiters.TryDequeue(out var waiter)) + { + // Don't increment the count if the waiter has already been + // completed before (most likely by getting canceled). + if (waiter?.TrySetResult() == true) + _count++; + } + } + } + + private void Release() + { + using (_lock.EnterScope()) + _count--; + + Refresh(); + } + + /// + /// Acquires access to the semaphore, waiting asynchronously if the maximum concurrency count + /// has been reached. Dispose the returned handle to release access. + /// + public async Task AcquireAsync(CancellationToken cancellationToken = default) + { + var waiter = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (_cts.Token.Register(() => waiter.TrySetCanceled(_cts.Token))) + using (cancellationToken.Register(() => waiter.TrySetCanceled(cancellationToken))) + using (_lock.EnterScope()) + { + ObjectDisposedException.ThrowIf(_isDisposed, this); + _waiters.Enqueue(waiter); + } + + Refresh(); + await waiter.Task.ConfigureAwait(false); + + return Disposable.Create(Release); + } + + /// + public void Dispose() + { + using (_lock.EnterScope()) + { + if (_isDisposed) + return; + + _isDisposed = true; + _cts.Cancel(); + } + + _cts.Dispose(); + } +} +#endif