diff --git a/docs/.vitepress/config.mts b/docs/.vitepress/config.mts index d522fec45..a52196414 100644 --- a/docs/.vitepress/config.mts +++ b/docs/.vitepress/config.mts @@ -113,6 +113,7 @@ const config: UserConfig = { text: 'Message Handlers', link: '/guide/handlers/', items: [ {text: 'Discovery', link: '/guide/handlers/discovery'}, {text: 'Error Handling', link: '/guide/handlers/error-handling'}, + {text: 'Rate Limiting', link: '/guide/handlers/rate-limiting'}, {text: 'Return Values', link: '/guide/handlers/return-values'}, {text: 'Cascading Messages', link: '/guide/handlers/cascading'}, {text: 'Side Effects', link: '/guide/handlers/side-effects'}, diff --git a/docs/guide/handlers/rate-limiting.md b/docs/guide/handlers/rate-limiting.md new file mode 100644 index 000000000..c37b46432 --- /dev/null +++ b/docs/guide/handlers/rate-limiting.md @@ -0,0 +1,59 @@ +# Rate Limiting + +Wolverine can enforce distributed rate limits for message handlers by re-queuing and pausing the listener when limits are exceeded. This is intended for external API usage limits that must be respected across multiple worker nodes. + +## Message Type Rate Limits + +Use `RateLimit` on a message type policy to set a default limit and optional time-of-day overrides: + +```cs +using Wolverine; +using Wolverine.RateLimiting; + +opts.Policies.ForMessagesOfType() + .RateLimit(RateLimit.PerMinute(900), schedule => + { + schedule.TimeZone = TimeZoneInfo.Utc; + schedule.AddWindow(new TimeOnly(8, 0), new TimeOnly(17, 0), RateLimit.PerMinute(400)); + }); +``` + +The middleware enforces the limit before handler execution. If the limit is exceeded, Wolverine re-schedules the message and pauses the listener for the computed delay. + +## Endpoint Rate Limits + +You can also rate limit an entire listening endpoint: + +```cs +using Wolverine; +using Wolverine.RateLimiting; + +opts.RateLimitEndpoint(new Uri("rabbitmq://queue/critical"), RateLimit.PerMinute(400)); +``` + +Endpoint limits take precedence over message type limits when both are configured. + +## Distributed Store + +Rate limiting relies on a shared store. By default, Wolverine registers an in-memory store for tests and local development. For production, register a shared store implementation. + +### SQL Server + +```cs +using Wolverine; +using Wolverine.SqlServer; + +opts.PersistMessagesWithSqlServer(connectionString) + .UseSqlServerRateLimiting(); +``` + +This uses the Wolverine message storage schema by default (same schema as the inbox/outbox tables). + +## Scheduling Requirements + +Rate limiting re-schedules messages through Wolverine's scheduling pipeline. For external listeners, Wolverine requires durable inboxes to ensure rescheduled messages are persisted correctly. + +```cs +opts.ListenToRabbitQueue("critical").UseDurableInbox(); +// or: opts.Policies.UseDurableInboxOnAllListeners(); +``` diff --git a/src/Persistence/SqlServerTests/rate_limiting_storage.cs b/src/Persistence/SqlServerTests/rate_limiting_storage.cs new file mode 100644 index 000000000..48ea09236 --- /dev/null +++ b/src/Persistence/SqlServerTests/rate_limiting_storage.cs @@ -0,0 +1,111 @@ +using IntegrationTests; +using JasperFx.Core; +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging.Abstractions; +using Shouldly; +using Weasel.SqlServer; +using JasperFx.Resources; +using Wolverine; +using Wolverine.ComplianceTests.RateLimiting; +using Wolverine.Persistence.Durability; +using Wolverine.RateLimiting; +using Wolverine.RDBMS; +using Wolverine.RDBMS.Sagas; +using Wolverine.SqlServer; +using Wolverine.SqlServer.Persistence; +using Wolverine.SqlServer.RateLimiting; +using Wolverine.SqlServer.Schema; +using Xunit; + +namespace SqlServerTests; + +public class rate_limiting_storage : RateLimitStoreCompliance +{ + private readonly string _schemaName = $"rate_limits_{Guid.NewGuid():N}"; + private IHost? _host; + + protected override async Task BuildStoreAsync() + { + await waitForSqlServerAsync(); + using var conn = new SqlConnection(Servers.SqlServerConnectionString); + await conn.OpenAsync(); + await conn.DropSchemaAsync(_schemaName); + await conn.CloseAsync(); + + _host = await Host.CreateDefaultBuilder() + .UseWolverine(opts => + { + opts.PersistMessagesWithSqlServer(Servers.SqlServerConnectionString, _schemaName) + .UseSqlServerRateLimiting(); + opts.Services.AddResourceSetupOnStartup(); + }).StartAsync(); + + var settings = new DatabaseSettings + { + ConnectionString = Servers.SqlServerConnectionString, + SchemaName = _schemaName + }; + + var persistence = new SqlServerMessageStore(settings, new DurabilitySettings(), + NullLogger.Instance, Array.Empty()); + persistence.AddTable(new RateLimitTable(_schemaName, "wolverine_rate_limits")); + await persistence.RebuildAsync(); + + return new SqlServerRateLimitStore(settings, new SqlServerRateLimitOptions { SchemaName = _schemaName }); + } + + protected override async Task DisposeStoreAsync(IRateLimitStore store) + { + if (_host != null) + { + await _host.StopAsync(); + _host.Dispose(); + } + } + + private static async Task waitForSqlServerAsync() + { + const int maxAttempts = 15; + var delay = TimeSpan.FromSeconds(1); + for (var attempt = 1; attempt <= maxAttempts; attempt++) + { + try + { + await using var conn = new SqlConnection(Servers.SqlServerConnectionString); + await conn.OpenAsync(); + await conn.CloseAsync(); + return; + } + catch (SqlException) when (attempt < maxAttempts) + { + await Task.Delay(delay); + } + } + } + + [Fact] + public async Task creates_rate_limit_table_on_startup() + { + using var conn = new SqlConnection(Servers.SqlServerConnectionString); + await conn.OpenAsync(); + await using var cmd = conn.CreateCommand(); + cmd.CommandText = + "SELECT table_schema, table_name FROM information_schema.tables WHERE table_schema = @schema AND table_name = @name"; + cmd.Parameters.Add(new SqlParameter("@schema", _schemaName)); + cmd.Parameters.Add(new SqlParameter("@name", "wolverine_rate_limits")); + + var found = false; + await using (var reader = await cmd.ExecuteReaderAsync()) + { + if (await reader.ReadAsync()) + { + found = true; + } + } + + await conn.CloseAsync(); + + found.ShouldBeTrue(); + } +} diff --git a/src/Persistence/Wolverine.Postgresql/PostgresqlBackedPersistence.cs b/src/Persistence/Wolverine.Postgresql/PostgresqlBackedPersistence.cs index 5aa08b50a..f49d2c800 100644 --- a/src/Persistence/Wolverine.Postgresql/PostgresqlBackedPersistence.cs +++ b/src/Persistence/Wolverine.Postgresql/PostgresqlBackedPersistence.cs @@ -1,4 +1,4 @@ -using System.Data.Common; +using System.Data.Common; using JasperFx; using JasperFx.Core; using JasperFx.MultiTenancy; diff --git a/src/Persistence/Wolverine.SqlServer/Persistence/SqlServerMessageStore.cs b/src/Persistence/Wolverine.SqlServer/Persistence/SqlServerMessageStore.cs index 9da9a3d6e..2c9c520e0 100644 --- a/src/Persistence/Wolverine.SqlServer/Persistence/SqlServerMessageStore.cs +++ b/src/Persistence/Wolverine.SqlServer/Persistence/SqlServerMessageStore.cs @@ -1,4 +1,4 @@ -using System.Data; +using System.Data; using System.Data.Common; using ImTools; using JasperFx; @@ -519,6 +519,11 @@ public override IEnumerable AllObjects() } } + public void AddTable(Table table) + { + _externalTables.Add(table); + } + public override IDatabaseSagaSchema SagaSchemaFor() { if (_sagaStorage.TryFind(typeof(TSaga), out var raw)) diff --git a/src/Persistence/Wolverine.SqlServer/RateLimiting/SqlServerRateLimitStore.cs b/src/Persistence/Wolverine.SqlServer/RateLimiting/SqlServerRateLimitStore.cs new file mode 100644 index 000000000..1c148ef07 --- /dev/null +++ b/src/Persistence/Wolverine.SqlServer/RateLimiting/SqlServerRateLimitStore.cs @@ -0,0 +1,64 @@ +using System.Data; +using Microsoft.Data.SqlClient; +using Weasel.Core; +using Wolverine.RateLimiting; +using Wolverine.RDBMS; +using Wolverine.SqlServer.Schema; + +namespace Wolverine.SqlServer.RateLimiting; + +public sealed class SqlServerRateLimitOptions +{ + public string? SchemaName { get; set; } + public string TableName { get; set; } = "wolverine_rate_limits"; +} + +public sealed class SqlServerRateLimitStore : IRateLimitStore +{ + private readonly string _connectionString; + private readonly string _qualifiedTable; + private readonly SqlServerRateLimitOptions _options; + + public SqlServerRateLimitStore(DatabaseSettings settings, SqlServerRateLimitOptions options) + { + _options = options; + _connectionString = settings.ConnectionString ?? throw new InvalidOperationException("Connection string is required."); + + var schemaName = _options.SchemaName ?? settings.SchemaName ?? "dbo"; + _qualifiedTable = new DbObjectName(schemaName, _options.TableName).QualifiedName; + } + + public async ValueTask TryAcquireAsync(RateLimitStoreRequest request, + CancellationToken cancellationToken) + { + await using var conn = new SqlConnection(_connectionString); + await conn.OpenAsync(cancellationToken).ConfigureAwait(false); + + await using var cmd = conn.CreateCommand(); + cmd.CommandText = $@" +MERGE {_qualifiedTable} WITH (HOLDLOCK) AS target +USING (SELECT @key AS {RateLimitTableColumns.Key}, @windowStart AS {RateLimitTableColumns.WindowStart}) AS source +ON target.{RateLimitTableColumns.Key} = source.{RateLimitTableColumns.Key} + AND target.{RateLimitTableColumns.WindowStart} = source.{RateLimitTableColumns.WindowStart} +WHEN MATCHED THEN + UPDATE SET {RateLimitTableColumns.CurrentCount} = target.{RateLimitTableColumns.CurrentCount} + @quantity, + {RateLimitTableColumns.WindowEnd} = @windowEnd, + {RateLimitTableColumns.Limit} = @limitPerWindow +WHEN NOT MATCHED THEN + INSERT ({RateLimitTableColumns.Key}, {RateLimitTableColumns.WindowStart}, {RateLimitTableColumns.WindowEnd}, {RateLimitTableColumns.Limit}, {RateLimitTableColumns.CurrentCount}) + VALUES (@key, @windowStart, @windowEnd, @limitPerWindow, @quantity) +OUTPUT inserted.{RateLimitTableColumns.CurrentCount}; +"; + + cmd.Parameters.Add(new SqlParameter("@key", SqlDbType.VarChar, 500) { Value = request.Key }); + cmd.Parameters.Add(new SqlParameter("@windowStart", SqlDbType.DateTimeOffset) { Value = request.Bucket.WindowStart }); + cmd.Parameters.Add(new SqlParameter("@windowEnd", SqlDbType.DateTimeOffset) { Value = request.Bucket.WindowEnd }); + cmd.Parameters.Add(new SqlParameter("@limitPerWindow", SqlDbType.Int) { Value = request.Bucket.Limit }); + cmd.Parameters.Add(new SqlParameter("@quantity", SqlDbType.Int) { Value = request.Quantity }); + + var current = (int)await cmd.ExecuteScalarAsync(cancellationToken).ConfigureAwait(false); + var allowed = current <= request.Bucket.Limit; + + return new RateLimitStoreResult(allowed, current); + } +} diff --git a/src/Persistence/Wolverine.SqlServer/Schema/RateLimitTable.cs b/src/Persistence/Wolverine.SqlServer/Schema/RateLimitTable.cs new file mode 100644 index 000000000..b07322232 --- /dev/null +++ b/src/Persistence/Wolverine.SqlServer/Schema/RateLimitTable.cs @@ -0,0 +1,25 @@ +using Weasel.Core; +using Weasel.SqlServer.Tables; + +namespace Wolverine.SqlServer.Schema; + +internal static class RateLimitTableColumns +{ + public const string Key = "rate_limit_key"; + public const string WindowStart = "window_start"; + public const string WindowEnd = "window_end"; + public const string Limit = "limit_per_window"; + public const string CurrentCount = "current_count"; +} + +internal class RateLimitTable : Table +{ + public RateLimitTable(string schemaName, string tableName) : base(new DbObjectName(schemaName, tableName)) + { + AddColumn(RateLimitTableColumns.Key, "varchar(500)").NotNull().AsPrimaryKey(); + AddColumn(RateLimitTableColumns.WindowStart).NotNull().AsPrimaryKey(); + AddColumn(RateLimitTableColumns.WindowEnd).NotNull(); + AddColumn(RateLimitTableColumns.Limit).NotNull(); + AddColumn(RateLimitTableColumns.CurrentCount).NotNull(); + } +} diff --git a/src/Persistence/Wolverine.SqlServer/SqlServerBackedPersistence.cs b/src/Persistence/Wolverine.SqlServer/SqlServerBackedPersistence.cs index 837f5f065..d97682f4d 100644 --- a/src/Persistence/Wolverine.SqlServer/SqlServerBackedPersistence.cs +++ b/src/Persistence/Wolverine.SqlServer/SqlServerBackedPersistence.cs @@ -1,4 +1,4 @@ -using System.Data.Common; +using System.Data.Common; using JasperFx; using JasperFx.CodeGeneration.Model; using JasperFx.Core; @@ -112,12 +112,20 @@ public interface ISqlServerBackedPersistence internal class SqlServerBackedPersistence : IWolverineExtension, ISqlServerBackedPersistence { private readonly WolverineOptions _options; + private readonly List> _storeConfigurations = new(); public SqlServerBackedPersistence(WolverineOptions options) { _options = options; } + internal WolverineOptions Options => _options; + + internal void AddStoreConfiguration(Action configuration) + { + _storeConfigurations.Add(configuration); + } + public string? ConnectionString { get; set; } public string EnvelopeStorageSchemaName { get; set; } = "wolverine"; @@ -194,6 +202,7 @@ public IMessageStore BuildMessageStore(IWolverineRuntime runtime) { var defaultStore = new SqlServerMessageStore(settings, runtime.DurabilitySettings, logger, sagaTables); + applyStoreConfigurations(defaultStore); ConnectionStringTenancy = new MasterTenantSource(defaultStore, runtime.Options); @@ -205,6 +214,7 @@ public IMessageStore BuildMessageStore(IWolverineRuntime runtime) { var defaultStore = new SqlServerMessageStore(settings, runtime.DurabilitySettings, logger, sagaTables); + applyStoreConfigurations(defaultStore); return new MultiTenantedMessageStore(defaultStore, runtime, new SqlServerTenantedMessageStore(runtime, this, sagaTables){DataSource = ConnectionStringTenancy}); @@ -212,8 +222,23 @@ public IMessageStore BuildMessageStore(IWolverineRuntime runtime) settings.Role = Role; - return new SqlServerMessageStore(settings, runtime.DurabilitySettings, + var store = new SqlServerMessageStore(settings, runtime.DurabilitySettings, logger, sagaTables); + applyStoreConfigurations(store); + return store; + } + + internal void ApplyStoreConfigurations(SqlServerMessageStore store) + { + applyStoreConfigurations(store); + } + + private void applyStoreConfigurations(SqlServerMessageStore store) + { + foreach (var configuration in _storeConfigurations) + { + configuration(store); + } } private DatabaseSettings buildMainDatabaseSettings() diff --git a/src/Persistence/Wolverine.SqlServer/SqlServerConfigurationExtensions.cs b/src/Persistence/Wolverine.SqlServer/SqlServerConfigurationExtensions.cs index 8b2e435ec..9d5e6986b 100644 --- a/src/Persistence/Wolverine.SqlServer/SqlServerConfigurationExtensions.cs +++ b/src/Persistence/Wolverine.SqlServer/SqlServerConfigurationExtensions.cs @@ -1,12 +1,16 @@ -using JasperFx.Core; +using JasperFx.Core; using JasperFx.Core.Reflection; using Microsoft.Data.SqlClient; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; using Weasel.Core.Migrations; using Wolverine.Configuration; using Wolverine.ErrorHandling; using Wolverine.Persistence.Durability; +using Wolverine.RateLimiting; using Wolverine.RDBMS; +using Wolverine.SqlServer.RateLimiting; +using Wolverine.SqlServer.Schema; using Wolverine.SqlServer.Transport; namespace Wolverine.SqlServer; @@ -46,6 +50,29 @@ public static ISqlServerBackedPersistence PersistMessagesWithSqlServer(this Wolv return extension; } + /// + /// Register SQL Server backed rate limiting storage + /// + public static ISqlServerBackedPersistence UseSqlServerRateLimiting(this ISqlServerBackedPersistence persistence, + Action? configure = null) + { + var options = new SqlServerRateLimitOptions(); + configure?.Invoke(options); + + var concrete = persistence.As(); + var schemaName = options.SchemaName ?? concrete.EnvelopeStorageSchemaName; + + concrete.AddStoreConfiguration(store => + { + store.AddTable(new RateLimitTable(schemaName, options.TableName)); + }); + + concrete.Options.Services.TryAddSingleton(options); + concrete.Options.Services.TryAddSingleton(); + + return persistence; + } + /// /// Register Sql Server backed message persistence *and* the Sql Server messaging transport /// diff --git a/src/Persistence/Wolverine.SqlServer/SqlServerTenantedMessageStore.cs b/src/Persistence/Wolverine.SqlServer/SqlServerTenantedMessageStore.cs index d0f81e26c..660c9a3e6 100644 --- a/src/Persistence/Wolverine.SqlServer/SqlServerTenantedMessageStore.cs +++ b/src/Persistence/Wolverine.SqlServer/SqlServerTenantedMessageStore.cs @@ -74,6 +74,7 @@ private SqlServerMessageStore buildTenantStoreForConnectionString(string connect store = new SqlServerMessageStore(settings, _runtime.Options.Durability, _runtime.LoggerFactory.CreateLogger(), _sagaTables); + _persistence.ApplyStoreConfigurations(store); return store; } diff --git a/src/Testing/CoreTests/RateLimiting/rate_limiting_configuration.cs b/src/Testing/CoreTests/RateLimiting/rate_limiting_configuration.cs new file mode 100644 index 000000000..b75c7f83c --- /dev/null +++ b/src/Testing/CoreTests/RateLimiting/rate_limiting_configuration.cs @@ -0,0 +1,45 @@ +using Microsoft.Extensions.Logging.Abstractions; +using Shouldly; +using Wolverine; +using Wolverine.RateLimiting; +using Xunit; + +namespace CoreTests.RateLimiting; + +public class rate_limiting_configuration +{ + [Fact] + public void schedule_selects_matching_window() + { + var schedule = new RateLimitSchedule(RateLimit.PerMinute(900)) + { + TimeZone = TimeZoneInfo.Utc + }; + + schedule.AddWindow(new TimeOnly(8, 0), new TimeOnly(17, 0), RateLimit.PerMinute(400)); + + schedule.Resolve(new DateTimeOffset(2024, 1, 1, 9, 0, 0, TimeSpan.Zero)) + .Permits.ShouldBe(400); + + schedule.Resolve(new DateTimeOffset(2024, 1, 1, 18, 0, 0, TimeSpan.Zero)) + .Permits.ShouldBe(900); + } + + [Fact] + public async Task rate_limiter_denies_after_limit() + { + var options = new WolverineOptions(); + options.Policies.ForMessagesOfType() + .RateLimit(RateLimit.PerMinute(2)); + + var limiter = new RateLimiter(new InMemoryRateLimitStore(), options, NullLogger.Instance); + var envelope = new Envelope(new RateLimitedMessage()); + var now = new DateTimeOffset(2024, 1, 1, 0, 0, 0, TimeSpan.Zero); + + (await limiter.CheckAsync(envelope, now, CancellationToken.None)).Allowed.ShouldBeTrue(); + (await limiter.CheckAsync(envelope, now, CancellationToken.None)).Allowed.ShouldBeTrue(); + (await limiter.CheckAsync(envelope, now, CancellationToken.None)).Allowed.ShouldBeFalse(); + } + + private sealed record RateLimitedMessage; +} diff --git a/src/Testing/CoreTests/RateLimiting/rate_limiting_core_tests.cs b/src/Testing/CoreTests/RateLimiting/rate_limiting_core_tests.cs new file mode 100644 index 000000000..e12597984 --- /dev/null +++ b/src/Testing/CoreTests/RateLimiting/rate_limiting_core_tests.cs @@ -0,0 +1,267 @@ +using JasperFx.Core; +using Microsoft.Extensions.Logging.Abstractions; +using NSubstitute; +using Shouldly; +using System.Linq; +using Wolverine; +using Wolverine.Configuration; +using Wolverine.RateLimiting; +using Wolverine.Runtime; +using Wolverine.Transports; +using Wolverine.Transports.Local; +using CoreTests.Runtime; +using Xunit; + +namespace CoreTests.RateLimiting; + +public class rate_limiting_core_tests +{ + [Fact] + public void rate_limit_window_matches_normal_and_wrapped_ranges() + { + var normal = new RateLimitWindow(new TimeOnly(8, 0), new TimeOnly(17, 0), RateLimit.PerMinute(10)); + normal.Matches(new TimeOnly(9, 0)).ShouldBeTrue(); + normal.Matches(new TimeOnly(18, 0)).ShouldBeFalse(); + + var wrapped = new RateLimitWindow(new TimeOnly(22, 0), new TimeOnly(2, 0), RateLimit.PerMinute(10)); + wrapped.Matches(new TimeOnly(23, 0)).ShouldBeTrue(); + wrapped.Matches(new TimeOnly(1, 0)).ShouldBeTrue(); + wrapped.Matches(new TimeOnly(12, 0)).ShouldBeFalse(); + } + + [Fact] + public void rate_limit_window_start_equals_end_matches_all_times() + { + var window = new RateLimitWindow(new TimeOnly(0, 0), new TimeOnly(0, 0), RateLimit.PerMinute(1)); + window.Matches(new TimeOnly(0, 0)).ShouldBeTrue(); + window.Matches(new TimeOnly(12, 0)).ShouldBeTrue(); + } + + [Fact] + public void schedule_throws_on_invalid_default_limit() + { + var schedule = new RateLimitSchedule(new RateLimit(0, 1.Minutes())); + Should.Throw(() => + schedule.Resolve(new DateTimeOffset(2024, 1, 1, 9, 0, 0, TimeSpan.Zero))); + } + + [Fact] + public void schedule_throws_on_invalid_window_limit() + { + var schedule = new RateLimitSchedule(new RateLimit(1, 1.Minutes())) + { + TimeZone = TimeZoneInfo.Utc + }; + schedule.AddWindow(new TimeOnly(8, 0), new TimeOnly(9, 0), new RateLimit(0, 1.Minutes())); + + Should.Throw(() => + schedule.Resolve(new DateTimeOffset(2024, 1, 1, 8, 30, 0, TimeSpan.Zero))); + } + + [Fact] + public void settings_registry_matches_assignable_message_type() + { + var registry = new RateLimitSettingsRegistry(); + var schedule = new RateLimitSchedule(RateLimit.PerMinute(5)); + registry.RegisterMessageType(typeof(IBaseMessage), new RateLimitSettings("base", schedule)); + + registry.TryFindForMessageType(typeof(DerivedMessage), out var settings).ShouldBeTrue(); + settings.Key.ShouldBe("base"); + } + + [Fact] + public void settings_registry_matches_endpoint() + { + var registry = new RateLimitSettingsRegistry(); + var endpoint = new Uri("local://rate-limited"); + var schedule = new RateLimitSchedule(RateLimit.PerMinute(5)); + registry.RegisterEndpoint(endpoint, new RateLimitSettings("endpoint", schedule)); + + registry.TryFindForEndpoint(endpoint, out var settings).ShouldBeTrue(); + settings.Key.ShouldBe("endpoint"); + } + + [Fact] + public async Task in_memory_store_enforces_limit_per_window() + { + var store = new InMemoryRateLimitStore(); + var limit = new RateLimit(2, 1.Minutes()); + var now = new DateTimeOffset(2024, 1, 1, 0, 0, 0, TimeSpan.Zero); + var bucket = RateLimitBucket.For(limit, now); + + (await store.TryAcquireAsync(new RateLimitStoreRequest("key", bucket, 1, now), CancellationToken.None)) + .Allowed.ShouldBeTrue(); + (await store.TryAcquireAsync(new RateLimitStoreRequest("key", bucket, 1, now), CancellationToken.None)) + .Allowed.ShouldBeTrue(); + (await store.TryAcquireAsync(new RateLimitStoreRequest("key", bucket, 1, now), CancellationToken.None)) + .Allowed.ShouldBeFalse(); + + var later = now.AddMinutes(1).AddSeconds(1); + var nextBucket = RateLimitBucket.For(limit, later); + (await store.TryAcquireAsync(new RateLimitStoreRequest("key", nextBucket, 1, later), CancellationToken.None)) + .Allowed.ShouldBeTrue(); + } + + [Fact] + public async Task rate_limiter_prefers_listener_endpoint_over_message_type() + { + var options = new WolverineOptions(); + options.Policies.ForMessagesOfType() + .RateLimit("message", RateLimit.PerMinute(1)); + + var listenerUri = new Uri("stub://listener"); + options.RateLimitEndpoint(listenerUri, RateLimit.PerMinute(1), key: "endpoint"); + + var store = new CapturingRateLimitStore(); + var limiter = new RateLimiter(store, options, NullLogger.Instance); + + var envelope = new Envelope(new DerivedMessage()) + { + Listener = new FakeListener(listenerUri) + }; + + await limiter.CheckAsync(envelope, new DateTimeOffset(2024, 1, 1, 0, 0, 0, TimeSpan.Zero), + CancellationToken.None); + + store.LastRequest!.Key.ShouldBe("endpoint"); + } + + [Fact] + public async Task rate_limiter_uses_destination_when_listener_missing() + { + var options = new WolverineOptions(); + options.Policies.ForMessagesOfType() + .RateLimit("message", RateLimit.PerMinute(1)); + + var destination = new Uri("stub://destination"); + options.RateLimitEndpoint(destination, RateLimit.PerMinute(1), key: "endpoint"); + + var store = new CapturingRateLimitStore(); + var limiter = new RateLimiter(store, options, NullLogger.Instance); + + var envelope = new Envelope(new DerivedMessage()) + { + Destination = destination + }; + + await limiter.CheckAsync(envelope, new DateTimeOffset(2024, 1, 1, 0, 0, 0, TimeSpan.Zero), + CancellationToken.None); + + store.LastRequest!.Key.ShouldBe("endpoint"); + } + + [Fact] + public async Task middleware_throws_rate_limit_exception_with_expected_data() + { + var options = new WolverineOptions(); + options.Policies.ForMessagesOfType() + .RateLimit("key", RateLimit.PerMinute(1)); + + var limiter = new RateLimiter(new DenyRateLimitStore(), options, NullLogger.Instance); + var middleware = new RateLimitMiddleware(limiter); + + var envelope = new Envelope(new DerivedMessage()); + var ex = await Should.ThrowAsync(() => + middleware.BeforeAsync(envelope, CancellationToken.None)); + + ex.Key.ShouldBe("key"); + ex.Limit.Permits.ShouldBe(1); + ex.RetryAfter.ShouldBeGreaterThan(TimeSpan.Zero); + ex.RetryAfter.ShouldBeLessThanOrEqualTo(1.Minutes()); + } + + [Fact] + public async Task continuation_reschedules_and_pauses_listener() + { + var runtime = new MockWolverineRuntime(); + var listenerUri = new Uri("stub://listener"); + var listener = new FakeListener(listenerUri); + var agent = Substitute.For(); + agent.Endpoint.Returns(new LocalQueue("rate-limited")); + agent.PauseAsync(Arg.Any()).Returns(ValueTask.CompletedTask); + + runtime.Endpoints.FindListeningAgent(listenerUri).Returns(agent); + + var envelope = new Envelope(new DerivedMessage()) + { + Listener = listener + }; + + var lifecycle = Substitute.For(); + lifecycle.Envelope.Returns(envelope); + lifecycle.ReScheduleAsync(Arg.Any()).Returns(Task.CompletedTask); + + var now = new DateTimeOffset(2024, 1, 1, 0, 0, 0, TimeSpan.Zero); + var pause = 5.Seconds(); + + var continuation = new RateLimitContinuation(pause); + await continuation.ExecuteAsync(lifecycle, runtime, now, null); + + await lifecycle.Received(1).ReScheduleAsync(now.Add(pause)); + await agent.Received(1).PauseAsync(pause); + } + + [Fact] + public void continuation_source_describes_and_builds_from_exception() + { + var source = new RateLimitContinuationSource(); + source.Description.ShouldContain("Rate limit"); + + var envelope = new Envelope(new DerivedMessage()); + var exception = new RateLimitExceededException("key", RateLimit.PerMinute(1), 10.Seconds()); + var continuation = source.Build(exception, envelope).ShouldBeOfType(); + continuation.PauseTime.ShouldBe(10.Seconds()); + } + + [Fact] + public void options_registers_rate_limiting_services_once() + { + var options = new WolverineOptions(); + options.Policies.ForMessagesOfType() + .RateLimit(RateLimit.PerMinute(1)); + + options.RateLimitEndpoint(new Uri("local://rate-limited"), RateLimit.PerMinute(1)); + + options.Services.Count(x => x.ServiceType == typeof(IRateLimitStore)).ShouldBe(1); + options.Services.Count(x => x.ServiceType == typeof(RateLimiter)).ShouldBe(1); + } + + private interface IBaseMessage; + private sealed record DerivedMessage : IBaseMessage; + + private sealed class CapturingRateLimitStore : IRateLimitStore + { + public RateLimitStoreRequest? LastRequest { get; private set; } + + public ValueTask TryAcquireAsync(RateLimitStoreRequest request, + CancellationToken cancellationToken) + { + LastRequest = request; + return new ValueTask(new RateLimitStoreResult(true, 1)); + } + } + + private sealed class DenyRateLimitStore : IRateLimitStore + { + public ValueTask TryAcquireAsync(RateLimitStoreRequest request, + CancellationToken cancellationToken) + { + return new ValueTask(new RateLimitStoreResult(false, request.Bucket.Limit + 1)); + } + } + + private sealed class FakeListener : IListener + { + public FakeListener(Uri address) + { + Address = address; + } + + public Uri Address { get; } + public IHandlerPipeline? Pipeline => null; + public ValueTask CompleteAsync(Envelope envelope) => ValueTask.CompletedTask; + public ValueTask DeferAsync(Envelope envelope) => ValueTask.CompletedTask; + public ValueTask StopAsync() => ValueTask.CompletedTask; + public ValueTask DisposeAsync() => ValueTask.CompletedTask; + } +} diff --git a/src/Testing/Wolverine.ComplianceTests/RateLimiting/RateLimitStoreCompliance.cs b/src/Testing/Wolverine.ComplianceTests/RateLimiting/RateLimitStoreCompliance.cs new file mode 100644 index 000000000..69408e27b --- /dev/null +++ b/src/Testing/Wolverine.ComplianceTests/RateLimiting/RateLimitStoreCompliance.cs @@ -0,0 +1,91 @@ +using JasperFx.Core; +using Shouldly; +using Wolverine.RateLimiting; +using Xunit; + +namespace Wolverine.ComplianceTests.RateLimiting; + +public abstract class RateLimitStoreCompliance : IAsyncLifetime +{ + protected IRateLimitStore Store { get; private set; } = null!; + + protected abstract Task BuildStoreAsync(); + + protected virtual Task InitializeStoreAsync(IRateLimitStore store) + { + return Task.CompletedTask; + } + + protected virtual Task DisposeStoreAsync(IRateLimitStore store) + { + return Task.CompletedTask; + } + + public async Task InitializeAsync() + { + Store = await BuildStoreAsync(); + await InitializeStoreAsync(Store); + } + + public async Task DisposeAsync() + { + await DisposeStoreAsync(Store); + + switch (Store) + { + case IAsyncDisposable asyncDisposable: + await asyncDisposable.DisposeAsync(); + break; + case IDisposable disposable: + disposable.Dispose(); + break; + } + } + + [Fact] + public async Task allows_up_to_limit_then_denies() + { + var now = new DateTimeOffset(2024, 1, 1, 0, 0, 0, TimeSpan.Zero); + var limit = new RateLimit(2, 1.Minutes()); + var bucket = RateLimitBucket.For(limit, now); + var key = $"key-{Guid.NewGuid():N}"; + + (await Store.TryAcquireAsync(new RateLimitStoreRequest(key, bucket, 1, now), CancellationToken.None)) + .Allowed.ShouldBeTrue(); + (await Store.TryAcquireAsync(new RateLimitStoreRequest(key, bucket, 1, now), CancellationToken.None)) + .Allowed.ShouldBeTrue(); + (await Store.TryAcquireAsync(new RateLimitStoreRequest(key, bucket, 1, now), CancellationToken.None)) + .Allowed.ShouldBeFalse(); + } + + [Fact] + public async Task allows_again_in_next_bucket() + { + var now = new DateTimeOffset(2024, 1, 1, 0, 0, 0, TimeSpan.Zero); + var limit = new RateLimit(1, 1.Minutes()); + var key = $"key-{Guid.NewGuid():N}"; + + var bucket = RateLimitBucket.For(limit, now); + (await Store.TryAcquireAsync(new RateLimitStoreRequest(key, bucket, 1, now), CancellationToken.None)) + .Allowed.ShouldBeTrue(); + + var later = now.AddMinutes(1).AddSeconds(1); + var nextBucket = RateLimitBucket.For(limit, later); + (await Store.TryAcquireAsync(new RateLimitStoreRequest(key, nextBucket, 1, later), CancellationToken.None)) + .Allowed.ShouldBeTrue(); + } + + [Fact] + public async Task honors_quantity() + { + var now = new DateTimeOffset(2024, 1, 1, 0, 0, 0, TimeSpan.Zero); + var limit = new RateLimit(3, 1.Minutes()); + var bucket = RateLimitBucket.For(limit, now); + var key = $"key-{Guid.NewGuid():N}"; + + (await Store.TryAcquireAsync(new RateLimitStoreRequest(key, bucket, 2, now), CancellationToken.None)) + .Allowed.ShouldBeTrue(); + (await Store.TryAcquireAsync(new RateLimitStoreRequest(key, bucket, 2, now), CancellationToken.None)) + .Allowed.ShouldBeFalse(); + } +} diff --git a/src/Transports/RabbitMQ/Wolverine.RabbitMQ.Tests/rate_limiting_end_to_end.cs b/src/Transports/RabbitMQ/Wolverine.RabbitMQ.Tests/rate_limiting_end_to_end.cs new file mode 100644 index 000000000..6afa8a401 --- /dev/null +++ b/src/Transports/RabbitMQ/Wolverine.RabbitMQ.Tests/rate_limiting_end_to_end.cs @@ -0,0 +1,152 @@ +using IntegrationTests; +using JasperFx.Core; +using JasperFx.Resources; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Shouldly; +using Wolverine; +using Wolverine.Postgresql; +using Wolverine.RateLimiting; +using Xunit; + +namespace Wolverine.RabbitMQ.Tests; + +public class rate_limiting_end_to_end +{ + [Fact] + public async Task rate_limited_messages_are_delayed_over_rabbitmq() + { + var queueName = $"rate-limit-{Guid.NewGuid():N}"; + var schemaName = $"rate_limit_{Guid.NewGuid():N}"; + var tracker = new RateLimitTracker(); + var window = 1.Seconds(); + + IHost? publisher = null; + IHost? receiver = null; + + try + { + publisher = await Host.CreateDefaultBuilder() + .UseWolverine(opts => + { + opts.UseRabbitMq().DisableDeadLetterQueueing().AutoProvision().AutoPurgeOnStartup(); + opts.PublishAllMessages().ToRabbitQueue(queueName); + opts.Services.AddResourceSetupOnStartup(StartupAction.ResetState); + }).StartAsync(); + + receiver = await Host.CreateDefaultBuilder() + .UseWolverine(opts => + { + opts.ApplicationAssembly = typeof(rate_limiting_end_to_end).Assembly; + opts.Services.AddSingleton(tracker); + + opts.PersistMessagesWithPostgresql(Servers.PostgresConnectionString, schemaName); + opts.UseRabbitMq().DisableDeadLetterQueueing().AutoProvision().AutoPurgeOnStartup(); + opts.ListenToRabbitQueue(queueName).UseDurableInbox().Sequential(); + + opts.Policies.ForMessagesOfType() + .RateLimit("rabbitmq-rate-limit", new RateLimit(1, window)); + + opts.Services.AddResourceSetupOnStartup(StartupAction.ResetState); + }).StartAsync(); + + await publisher.ResetResourceState(); + await receiver.ResetResourceState(); + await alignToWindowStart(window); + + var bus = publisher.Services.GetRequiredService(); + await bus.PublishAsync(new RateLimitedMessage()); + await bus.PublishAsync(new RateLimitedMessage()); + + var first = await tracker.FirstHandled.Task.WaitAsync(10.Seconds()); + var second = await tracker.SecondHandled.Task.WaitAsync(10.Seconds()); + + (second - first).ShouldBeGreaterThanOrEqualTo(700.Milliseconds()); + } + finally + { + if (receiver != null) + { + await safeStopAsync(receiver); + } + + if (publisher != null) + { + await safeStopAsync(publisher); + } + } + } + + private static async Task alignToWindowStart(TimeSpan window) + { + var windowTicks = window.Ticks; + var thresholdTicks = 50.Milliseconds().Ticks; + + for (var attempt = 0; attempt < 200; attempt++) + { + if (DateTimeOffset.UtcNow.Ticks % windowTicks < thresholdTicks) + { + return; + } + + await Task.Delay(10.Milliseconds()); + } + + throw new TimeoutException("Could not align to rate limit window start."); + } + + private static async Task safeStopAsync(IHost host) + { + try + { + await host.StopAsync(); + } + catch (OperationCanceledException) + { + } + + try + { + host.Dispose(); + } + catch (OperationCanceledException) + { + } + } +} + +public record RateLimitedMessage; + +public class RateLimitTracker +{ + private int _count; + private readonly object _lock = new(); + + public TaskCompletionSource FirstHandled { get; } = new(TaskCreationOptions.RunContinuationsAsynchronously); + public TaskCompletionSource SecondHandled { get; } = new(TaskCreationOptions.RunContinuationsAsynchronously); + + public void RecordHandled() + { + var now = DateTimeOffset.UtcNow; + lock (_lock) + { + _count++; + if (_count == 1) + { + FirstHandled.TrySetResult(now); + } + else if (_count == 2) + { + SecondHandled.TrySetResult(now); + } + } + } +} + +public static class RateLimitedMessageHandler +{ + public static void Handle(RateLimitedMessage message, RateLimitTracker tracker) + { + tracker.RecordHandled(); + } +} diff --git a/src/Transports/Redis/Wolverine.Redis.Tests/rate_limiting_end_to_end.cs b/src/Transports/Redis/Wolverine.Redis.Tests/rate_limiting_end_to_end.cs new file mode 100644 index 000000000..035eb94d8 --- /dev/null +++ b/src/Transports/Redis/Wolverine.Redis.Tests/rate_limiting_end_to_end.cs @@ -0,0 +1,126 @@ +using JasperFx.Core; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Shouldly; +using Wolverine.Configuration; +using Wolverine.RateLimiting; +using Xunit; + +namespace Wolverine.Redis.Tests; + +public class rate_limiting_end_to_end +{ + [Fact] + public async Task rate_limited_messages_are_delayed_with_native_scheduling() + { + var streamKey = $"rate-limit-{Guid.NewGuid():N}"; + var groupName = $"rate-limit-group-{Guid.NewGuid():N}"; + var window = 2.Seconds(); + var limit = new RateLimit(1, window); + var tracker = new RedisRateLimitTracker(expectedCount: 2); + var endpointUri = new Uri($"redis://stream/0/{streamKey}"); + + using var host = await Host.CreateDefaultBuilder() + .UseWolverine(opts => + { + opts.ApplicationAssembly = typeof(rate_limiting_end_to_end).Assembly; + opts.Services.AddSingleton(tracker); + + opts.UseRedisTransport("localhost:6379").AutoProvision(); + opts.PublishAllMessages().ToRedisStream(streamKey); + opts.ListenToRedisStream(streamKey, groupName).StartFromBeginning(); + + opts.RateLimitEndpoint(endpointUri, limit); + }).StartAsync(); + + await Task.Delay(250.Milliseconds()); + await waitForNextBucketStartAsync(limit); + + var bus = host.MessageBus(); + await bus.PublishAsync(new RedisRateLimitedMessage(Guid.NewGuid().ToString())); + await bus.PublishAsync(new RedisRateLimitedMessage(Guid.NewGuid().ToString())); + + await tracker.WaitForHandledAsync(15.Seconds()); + + var handled = tracker.HandledTimes; + handled.Count.ShouldBeGreaterThanOrEqualTo(2); + + var firstBucket = RateLimitBucket.For(limit, handled[0]); + var secondBucket = RateLimitBucket.For(limit, handled[1]); + firstBucket.WindowStart.ShouldNotBe(secondBucket.WindowStart); + } + + private static async Task waitForNextBucketStartAsync(RateLimit limit) + { + var now = DateTimeOffset.UtcNow; + var bucket = RateLimitBucket.For(limit, now); + var delay = bucket.WindowEnd - now + 50.Milliseconds(); + if (delay < TimeSpan.Zero) + { + delay = 50.Milliseconds(); + } + + await Task.Delay(delay); + } +} + +public record RedisRateLimitedMessage(string Id); + +public class RedisRateLimitedMessageHandler +{ + private readonly RedisRateLimitTracker _tracker; + + public RedisRateLimitedMessageHandler(RedisRateLimitTracker tracker) + { + _tracker = tracker; + } + + public Task Handle(RedisRateLimitedMessage message, CancellationToken cancellationToken) + { + _tracker.RecordHandled(); + return Task.CompletedTask; + } +} + +public class RedisRateLimitTracker +{ + private readonly int _expectedCount; + private readonly TaskCompletionSource _completion = + new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly List _handledTimes = []; + private readonly object _lock = new(); + + public RedisRateLimitTracker(int expectedCount) + { + _expectedCount = expectedCount; + } + + public IReadOnlyList HandledTimes + { + get + { + lock (_lock) + { + return _handledTimes.ToList(); + } + } + } + + public void RecordHandled() + { + lock (_lock) + { + _handledTimes.Add(DateTimeOffset.UtcNow); + if (_handledTimes.Count >= _expectedCount) + { + _completion.TrySetResult(true); + } + } + } + + public async Task WaitForHandledAsync(TimeSpan timeout) + { + using var cts = new CancellationTokenSource(timeout); + await _completion.Task.WaitAsync(cts.Token); + } +} diff --git a/src/Wolverine/MessageTypePolicies.cs b/src/Wolverine/MessageTypePolicies.cs index b8a37bb66..b7ba0e6e1 100644 --- a/src/Wolverine/MessageTypePolicies.cs +++ b/src/Wolverine/MessageTypePolicies.cs @@ -1,6 +1,7 @@ using System.Linq.Expressions; using JasperFx.Core.Reflection; using Wolverine.Logging; +using Wolverine.RateLimiting; namespace Wolverine; @@ -51,4 +52,25 @@ public MessageTypePolicies AddMiddleware() { return AddMiddleware(typeof(TMiddleware)); } + + /// + /// Apply a rate limit schedule to messages assignable to type T + /// + public MessageTypePolicies RateLimit(RateLimit defaultLimit, Action? configure = null) + { + return RateLimit(null, defaultLimit, configure); + } + + /// + /// Apply a rate limit schedule to messages assignable to type T using a shared key + /// + public MessageTypePolicies RateLimit(string? key, RateLimit defaultLimit, + Action? configure = null) + { + var schedule = new RateLimitSchedule(defaultLimit); + configure?.Invoke(schedule); + _parent.ConfigureRateLimit(typeof(T), schedule, key); + + return this; + } } \ No newline at end of file diff --git a/src/Wolverine/RateLimiting/IRateLimitStore.cs b/src/Wolverine/RateLimiting/IRateLimitStore.cs new file mode 100644 index 000000000..6f230e46e --- /dev/null +++ b/src/Wolverine/RateLimiting/IRateLimitStore.cs @@ -0,0 +1,71 @@ +using System.Collections.Concurrent; + +namespace Wolverine.RateLimiting; + +public interface IRateLimitStore +{ + ValueTask TryAcquireAsync(RateLimitStoreRequest request, + CancellationToken cancellationToken); +} + +public sealed record RateLimitBucket(DateTimeOffset WindowStart, DateTimeOffset WindowEnd, int Limit) +{ + public static RateLimitBucket For(RateLimit limit, DateTimeOffset now) + { + var utcNow = now.UtcDateTime; + var windowTicks = limit.Window.Ticks; + var windowStartTicks = utcNow.Ticks - (utcNow.Ticks % windowTicks); + var windowStart = new DateTimeOffset(windowStartTicks, TimeSpan.Zero); + var windowEnd = windowStart.Add(limit.Window); + + return new RateLimitBucket(windowStart, windowEnd, limit.Permits); + } +} + +public sealed record RateLimitStoreRequest(string Key, RateLimitBucket Bucket, int Quantity, DateTimeOffset Now); + +public sealed record RateLimitStoreResult(bool Allowed, int CurrentCount); + +public sealed class InMemoryRateLimitStore : IRateLimitStore +{ + private readonly ConcurrentDictionary _buckets = new(); + + public ValueTask TryAcquireAsync(RateLimitStoreRequest request, + CancellationToken cancellationToken) + { + var bucketKey = $"{request.Key}:{request.Bucket.WindowStart.UtcTicks}"; + var state = _buckets.GetOrAdd(bucketKey, _ => new BucketState(request.Bucket.WindowEnd)); + var now = request.Now; + bool allowed; + int currentCount; + + lock (state.Lock) + { + if (state.WindowEnd <= now) + { + state.WindowEnd = request.Bucket.WindowEnd; + state.Count = 0; + } + + state.Count += request.Quantity; + currentCount = state.Count; + allowed = currentCount <= request.Bucket.Limit; + } + + return new ValueTask(new RateLimitStoreResult(allowed, currentCount)); + } + + private sealed class BucketState + { + public BucketState(DateTimeOffset windowEnd) + { + WindowEnd = windowEnd; + } + + public object Lock { get; } = new(); + + public int Count { get; set; } + + public DateTimeOffset WindowEnd { get; set; } + } +} diff --git a/src/Wolverine/RateLimiting/RateLimit.cs b/src/Wolverine/RateLimiting/RateLimit.cs new file mode 100644 index 000000000..3fb522437 --- /dev/null +++ b/src/Wolverine/RateLimiting/RateLimit.cs @@ -0,0 +1,94 @@ +using JasperFx.Core; + +namespace Wolverine.RateLimiting; + +public readonly record struct RateLimit(int Permits, TimeSpan Window) +{ + public static RateLimit PerSecond(int permits) => new(permits, 1.Seconds()); + public static RateLimit PerMinute(int permits) => new(permits, 1.Minutes()); + public static RateLimit PerHour(int permits) => new(permits, 1.Hours()); +} + +public sealed record RateLimitWindow(TimeOnly Start, TimeOnly End, RateLimit Limit) +{ + public bool Matches(TimeOnly time) + { + if (Start == End) + { + return true; + } + + if (Start < End) + { + return time >= Start && time < End; + } + + return time >= Start || time < End; + } +} + +public sealed class RateLimitSchedule +{ + private readonly List _windows = []; + + public RateLimitSchedule(RateLimit defaultLimit) + { + DefaultLimit = defaultLimit; + } + + public RateLimit DefaultLimit { get; } + + public TimeZoneInfo TimeZone { get; set; } = TimeZoneInfo.Local; + + public IReadOnlyList Windows => _windows; + + public RateLimitSchedule AddWindow(TimeOnly start, TimeOnly end, RateLimit limit) + { + _windows.Add(new RateLimitWindow(start, end, limit)); + return this; + } + + public RateLimit Resolve(DateTimeOffset now) + { + validate(DefaultLimit); + + var localTime = TimeZoneInfo.ConvertTime(now, TimeZone); + var time = TimeOnly.FromDateTime(localTime.DateTime); + foreach (var window in _windows) + { + if (window.Matches(time)) + { + validate(window.Limit); + return window.Limit; + } + } + + return DefaultLimit; + } + + private static void validate(RateLimit limit) + { + if (limit.Permits <= 0) + { + throw new ArgumentOutOfRangeException(nameof(limit.Permits), "Rate limit permits must be greater than zero."); + } + + if (limit.Window <= TimeSpan.Zero) + { + throw new ArgumentOutOfRangeException(nameof(limit.Window), "Rate limit window must be greater than zero."); + } + } +} + +public sealed class RateLimitSettings +{ + public RateLimitSettings(string key, RateLimitSchedule schedule) + { + Key = key ?? throw new ArgumentNullException(nameof(key)); + Schedule = schedule ?? throw new ArgumentNullException(nameof(schedule)); + } + + public string Key { get; } + + public RateLimitSchedule Schedule { get; } +} diff --git a/src/Wolverine/RateLimiting/RateLimitContinuation.cs b/src/Wolverine/RateLimiting/RateLimitContinuation.cs new file mode 100644 index 000000000..3bb1159df --- /dev/null +++ b/src/Wolverine/RateLimiting/RateLimitContinuation.cs @@ -0,0 +1,40 @@ +using System.Diagnostics; +using JasperFx.Core; +using Wolverine; +using Wolverine.ErrorHandling; +using Wolverine.Runtime; + +namespace Wolverine.RateLimiting; + +internal sealed class RateLimitContinuationSource : IContinuationSource +{ + public string Description => "Rate limit exceeded; reschedule and pause listener"; + + public IContinuation Build(Exception ex, Envelope envelope) + { + if (ex is RateLimitExceededException rateLimit) + { + return new RateLimitContinuation(rateLimit.RetryAfter); + } + + return new RateLimitContinuation(1.Milliseconds()); + } +} + +internal sealed class RateLimitContinuation : IContinuation +{ + public RateLimitContinuation(TimeSpan pauseTime) + { + PauseTime = pauseTime <= TimeSpan.Zero ? 1.Milliseconds() : pauseTime; + } + + public TimeSpan PauseTime { get; } + + public async ValueTask ExecuteAsync(IEnvelopeLifecycle lifecycle, IWolverineRuntime runtime, DateTimeOffset now, + Activity? activity) + { + await lifecycle.ReScheduleAsync(now.Add(PauseTime)).ConfigureAwait(false); + await new PauseListenerContinuation(PauseTime) + .ExecuteAsync(lifecycle, runtime, now, activity).ConfigureAwait(false); + } +} diff --git a/src/Wolverine/RateLimiting/RateLimitMiddleware.cs b/src/Wolverine/RateLimiting/RateLimitMiddleware.cs new file mode 100644 index 000000000..ec83f13fe --- /dev/null +++ b/src/Wolverine/RateLimiting/RateLimitMiddleware.cs @@ -0,0 +1,42 @@ +using Wolverine; + +namespace Wolverine.RateLimiting; + +public sealed class RateLimitMiddleware +{ + private readonly RateLimiter _limiter; + + public RateLimitMiddleware(RateLimiter limiter) + { + _limiter = limiter; + } + + public async Task BeforeAsync(Envelope envelope, CancellationToken cancellationToken) + { + var check = await _limiter.CheckAsync(envelope, cancellationToken).ConfigureAwait(false); + if (check.Allowed) + { + return; + } + + throw new RateLimitExceededException(check.Settings!.Key, check.Limit!.Value, + check.RetryAfter!.Value); + } +} + +public class RateLimitExceededException : Exception +{ + public RateLimitExceededException(string key, RateLimit limit, TimeSpan retryAfter) + : base($"Rate limit exceeded for '{key}'. Retry after {retryAfter.TotalSeconds:0.###} seconds.") + { + Key = key; + Limit = limit; + RetryAfter = retryAfter; + } + + public string Key { get; } + + public RateLimit Limit { get; } + + public TimeSpan RetryAfter { get; } +} diff --git a/src/Wolverine/RateLimiting/RateLimitSettingsRegistry.cs b/src/Wolverine/RateLimiting/RateLimitSettingsRegistry.cs new file mode 100644 index 000000000..65ba8204b --- /dev/null +++ b/src/Wolverine/RateLimiting/RateLimitSettingsRegistry.cs @@ -0,0 +1,46 @@ +using JasperFx.Core.Reflection; + +namespace Wolverine.RateLimiting; + +internal sealed class RateLimitSettingsRegistry +{ + private readonly Dictionary _messageTypeLimits = new(); + private readonly Dictionary _endpointLimits = new(StringComparer.OrdinalIgnoreCase); + + public bool HasAny => _messageTypeLimits.Count > 0 || _endpointLimits.Count > 0; + + public void RegisterMessageType(Type messageType, RateLimitSettings settings) + { + _messageTypeLimits[messageType] = settings; + } + + public void RegisterEndpoint(Uri endpoint, RateLimitSettings settings) + { + _endpointLimits[endpoint.ToString()] = settings; + } + + public bool TryFindForMessageType(Type messageType, out RateLimitSettings settings) + { + if (_messageTypeLimits.TryGetValue(messageType, out settings)) + { + return true; + } + + foreach (var pair in _messageTypeLimits) + { + if (messageType.CanBeCastTo(pair.Key)) + { + settings = pair.Value; + return true; + } + } + + settings = null!; + return false; + } + + public bool TryFindForEndpoint(Uri endpoint, out RateLimitSettings settings) + { + return _endpointLimits.TryGetValue(endpoint.ToString(), out settings); + } +} diff --git a/src/Wolverine/RateLimiting/RateLimiter.cs b/src/Wolverine/RateLimiting/RateLimiter.cs new file mode 100644 index 000000000..07415df99 --- /dev/null +++ b/src/Wolverine/RateLimiting/RateLimiter.cs @@ -0,0 +1,87 @@ +using Microsoft.Extensions.Logging; +using Wolverine; + +namespace Wolverine.RateLimiting; + +public sealed class RateLimiter +{ + private readonly IRateLimitStore _store; + private readonly RateLimitSettingsRegistry _settings; + private readonly ILogger _logger; + + public RateLimiter(IRateLimitStore store, WolverineOptions options, ILogger logger) + { + _store = store; + _settings = options.RateLimits; + _logger = logger; + } + + public ValueTask CheckAsync(Envelope envelope, CancellationToken cancellationToken) + { + return CheckAsync(envelope, DateTimeOffset.UtcNow, cancellationToken); + } + + internal async ValueTask CheckAsync(Envelope envelope, DateTimeOffset now, + CancellationToken cancellationToken) + { + if (!tryResolveSettings(envelope, out var settings)) + { + return RateLimitCheck.AllowedCheck(); + } + + var limit = settings.Schedule.Resolve(now); + var bucket = RateLimitBucket.For(limit, now); + var result = await _store.TryAcquireAsync( + new RateLimitStoreRequest(settings.Key, bucket, 1, now), + cancellationToken).ConfigureAwait(false); + + if (result.Allowed) + { + return RateLimitCheck.AllowedCheck(settings, limit); + } + + var retryAfter = bucket.WindowEnd - now; + if (retryAfter < TimeSpan.Zero) + { + retryAfter = TimeSpan.Zero; + } + + _logger.LogDebug("Rate limit exceeded for {Key}. Retry after {RetryAfter}", settings.Key, retryAfter); + return RateLimitCheck.Denied(settings, limit, retryAfter); + } + + private bool tryResolveSettings(Envelope envelope, out RateLimitSettings settings) + { + if (envelope.Listener != null && _settings.TryFindForEndpoint(envelope.Listener.Address, out settings)) + { + return true; + } + + if (envelope.Destination != null && _settings.TryFindForEndpoint(envelope.Destination, out settings)) + { + return true; + } + + if (envelope.Message != null && _settings.TryFindForMessageType(envelope.Message.GetType(), out settings)) + { + return true; + } + + settings = null!; + return false; + } +} + +public sealed record RateLimitCheck(bool Allowed, RateLimitSettings? Settings, RateLimit? Limit, + TimeSpan? RetryAfter) +{ + public static RateLimitCheck AllowedCheck(RateLimitSettings? settings = null, RateLimit? limit = null) + { + return new RateLimitCheck(true, settings, limit, null); + } + + public static RateLimitCheck Denied(RateLimitSettings settings, RateLimit limit, TimeSpan retryAfter) + { + return new RateLimitCheck(false, settings, limit, retryAfter); + } +} diff --git a/src/Wolverine/WolverineOptions.RateLimiting.cs b/src/Wolverine/WolverineOptions.RateLimiting.cs new file mode 100644 index 000000000..2f3f11b86 --- /dev/null +++ b/src/Wolverine/WolverineOptions.RateLimiting.cs @@ -0,0 +1,65 @@ +using Microsoft.Extensions.DependencyInjection.Extensions; +using Wolverine.ErrorHandling; +using Wolverine.RateLimiting; +using Wolverine.Util; + +namespace Wolverine; + +public sealed partial class WolverineOptions +{ + internal RateLimitSettingsRegistry RateLimits { get; } = new(); + + private bool _rateLimitingConfigured; + + internal void ConfigureRateLimit(Type messageType, RateLimitSchedule schedule, string? key = null) + { + var rateLimitKey = key ?? messageType.ToMessageTypeName(); + RateLimits.RegisterMessageType(messageType, new RateLimitSettings(rateLimitKey, schedule)); + + ensureRateLimitingConfigured(); + } + + public WolverineOptions RateLimitEndpoint(Uri endpoint, RateLimit defaultLimit, + Action? configure = null, string? key = null) + { + var schedule = new RateLimitSchedule(defaultLimit); + configure?.Invoke(schedule); + ConfigureEndpointRateLimit(endpoint, schedule, key); + + return this; + } + + internal void ConfigureEndpointRateLimit(Uri endpoint, RateLimitSchedule schedule, string? key = null) + { + var rateLimitKey = key ?? endpoint.ToString(); + RateLimits.RegisterEndpoint(endpoint, new RateLimitSettings(rateLimitKey, schedule)); + + ensureRateLimitingConfigured(); + } + + private void ensureRateLimitingConfigured() + { + if (_rateLimitingConfigured) + { + return; + } + + _rateLimitingConfigured = true; + + Policies.AddMiddleware(); + this.OnException() + .CustomActionIndefinitely(async (runtime, lifecycle, ex) => + { + if (lifecycle.Envelope == null) + { + return; + } + + var continuation = new RateLimitContinuationSource().Build(ex, lifecycle.Envelope); + await continuation.ExecuteAsync(lifecycle, runtime, DateTimeOffset.UtcNow, null).ConfigureAwait(false); + }, "Rate limit exceeded"); + + Services.TryAddSingleton(); + Services.TryAddSingleton(); + } +}