diff --git a/src/ReverseProxy.Core/Abstractions/BackendDiscovery/Contract/LoadBalancingMode.cs b/src/ReverseProxy.Core/Abstractions/BackendDiscovery/Contract/LoadBalancingMode.cs index f6d11a411a..c6c890e625 100644 --- a/src/ReverseProxy.Core/Abstractions/BackendDiscovery/Contract/LoadBalancingMode.cs +++ b/src/ReverseProxy.Core/Abstractions/BackendDiscovery/Contract/LoadBalancingMode.cs @@ -22,6 +22,10 @@ public enum LoadBalancingMode /// Random, /// + /// Selects an endpoint by cycling through them in order. + /// + RoundRobin, + /// /// Select the first endpoint without considering load. This is useful for dual endpoint fail-over systems. /// First, diff --git a/src/ReverseProxy.Core/Middleware/LoadBalancingMiddleware.cs b/src/ReverseProxy.Core/Middleware/LoadBalancingMiddleware.cs index 48b7edb75c..43fd427190 100644 --- a/src/ReverseProxy.Core/Middleware/LoadBalancingMiddleware.cs +++ b/src/ReverseProxy.Core/Middleware/LoadBalancingMiddleware.cs @@ -41,7 +41,8 @@ public Task Invoke(HttpContext context) var endpoints = endpointsFeature?.Endpoints ?? throw new InvalidOperationException("The AvailableBackendEndpoints collection was not set."); - var loadBalancingOptions = backend.Config.Value?.LoadBalancingOptions ?? default; + var loadBalancingOptions = backend.Config.Value?.LoadBalancingOptions + ?? new BackendConfig.BackendLoadBalancingOptions(default); var endpoint = _operationLogger.Execute( "ReverseProxy.PickEndpoint", diff --git a/src/ReverseProxy.Core/Service/Proxy/LoadBalancer.cs b/src/ReverseProxy.Core/Service/Proxy/LoadBalancer.cs index cdc4b6a29d..6f831d46ef 100644 --- a/src/ReverseProxy.Core/Service/Proxy/LoadBalancer.cs +++ b/src/ReverseProxy.Core/Service/Proxy/LoadBalancer.cs @@ -40,6 +40,9 @@ public EndpointInfo PickEndpoint( { case LoadBalancingMode.First: return endpoints[0]; + case LoadBalancingMode.RoundRobin: + var offset = loadBalancingOptions.RoundRobinCounter.Increment(); + return endpoints[offset % endpoints.Count]; case LoadBalancingMode.Random: var random = _randomFactory.CreateRandomInstance(); return endpoints[random.Next(endpointCount)]; diff --git a/src/ReverseProxy.Core/Service/RuntimeModel/BackendConfig.cs b/src/ReverseProxy.Core/Service/RuntimeModel/BackendConfig.cs index d3f1706bbc..aa31ee0a2a 100644 --- a/src/ReverseProxy.Core/Service/RuntimeModel/BackendConfig.cs +++ b/src/ReverseProxy.Core/Service/RuntimeModel/BackendConfig.cs @@ -3,6 +3,7 @@ using System; using Microsoft.ReverseProxy.Core.Abstractions; +using Microsoft.ReverseProxy.Core.Util; namespace Microsoft.ReverseProxy.Core.RuntimeModel { @@ -79,9 +80,13 @@ internal readonly struct BackendLoadBalancingOptions public BackendLoadBalancingOptions(LoadBalancingMode mode) { Mode = mode; + // Increment returns the new value and we want the first return value to be 0. + RoundRobinCounter = new AtomicCounter() { Value = -1 }; } public LoadBalancingMode Mode { get; } + + public AtomicCounter RoundRobinCounter { get; } } } } diff --git a/src/ReverseProxy.Core/Util/AtomicCounter.cs b/src/ReverseProxy.Core/Util/AtomicCounter.cs index d1a715f8cd..7094d71ecf 100644 --- a/src/ReverseProxy.Core/Util/AtomicCounter.cs +++ b/src/ReverseProxy.Core/Util/AtomicCounter.cs @@ -12,22 +12,25 @@ public class AtomicCounter /// /// Gets the current value of the counter. /// - public int Value => Volatile.Read(ref _value); + public int Value { + get => Volatile.Read(ref _value); + set => Volatile.Write(ref _value, value); + } /// /// Atomically increments the counter value by 1. /// - public void Increment() + public int Increment() { - Interlocked.Increment(ref _value); + return Interlocked.Increment(ref _value); } /// /// Atomically decrements the counter value by 1. /// - public void Decrement() + public int Decrement() { - Interlocked.Decrement(ref _value); + return Interlocked.Decrement(ref _value); } /// diff --git a/test/ReverseProxy.Core.Tests/Middleware/LoadBalancerMiddlewareTests.cs b/test/ReverseProxy.Core.Tests/Middleware/LoadBalancerMiddlewareTests.cs index c381da6eb9..01a32ab386 100644 --- a/test/ReverseProxy.Core.Tests/Middleware/LoadBalancerMiddlewareTests.cs +++ b/test/ReverseProxy.Core.Tests/Middleware/LoadBalancerMiddlewareTests.cs @@ -8,6 +8,7 @@ using Microsoft.AspNetCore.Routing.Patterns; using Microsoft.ReverseProxy.Common.Abstractions.Telemetry; using Microsoft.ReverseProxy.Common.Telemetry; +using Microsoft.ReverseProxy.Core.Abstractions; using Microsoft.ReverseProxy.Core.RuntimeModel; using Microsoft.ReverseProxy.Core.Service.Management; using Microsoft.ReverseProxy.Core.Service.Proxy; @@ -40,6 +41,7 @@ public async Task Invoke_Works() backendId: "backend1", endpointManager: new EndpointManager(), proxyHttpClientFactory: proxyHttpClientFactoryMock.Object); + backend1.Config.Value = new BackendConfig(default, new BackendConfig.BackendLoadBalancingOptions(LoadBalancingMode.RoundRobin)); var endpoint1 = backend1.EndpointManager.GetOrCreateItem( "endpoint1", endpoint => @@ -66,10 +68,7 @@ public async Task Invoke_Works() aspNetCoreEndpoints.Add(aspNetCoreEndpoint); var httpContext = new DefaultHttpContext(); httpContext.SetEndpoint(aspNetCoreEndpoint); - - Mock() - .Setup(l => l.PickEndpoint(It.IsAny>(), It.IsAny())) - .Returns(endpoint1); + Provide(); httpContext.Features.Set( new AvailableBackendEndpointsFeature() { Endpoints = new List() { endpoint1, endpoint2 }.AsReadOnly() }); diff --git a/test/ReverseProxy.Core.Tests/Service/Proxy/LoadBalancerTests.cs b/test/ReverseProxy.Core.Tests/Service/Proxy/LoadBalancerTests.cs index 4345a68f4f..2b7c067b50 100644 --- a/test/ReverseProxy.Core.Tests/Service/Proxy/LoadBalancerTests.cs +++ b/test/ReverseProxy.Core.Tests/Service/Proxy/LoadBalancerTests.cs @@ -135,6 +135,29 @@ public void PickEndpoint_LeastRequests_Works() Assert.Same(result, endpoints[1]); } + [Fact] + public void PickEndpoint_RoundRobin_Works() + { + var loadBalancer = Create(); + var endpoints = new[] + { + new EndpointInfo("ep1"), + new EndpointInfo("ep2"), + }; + endpoints[0].ConcurrencyCounter.Increment(); + var options = new BackendConfig.BackendLoadBalancingOptions(LoadBalancingMode.RoundRobin); + + var result0 = loadBalancer.PickEndpoint(endpoints, in options); + var result1 = loadBalancer.PickEndpoint(endpoints, in options); + var result2 = loadBalancer.PickEndpoint(endpoints, in options); + var result3 = loadBalancer.PickEndpoint(endpoints, in options); + + Assert.Same(result0, endpoints[0]); + Assert.Same(result1, endpoints[1]); + Assert.Same(result2, endpoints[0]); + Assert.Same(result3, endpoints[1]); + } + internal class TestRandomFactory : IRandomFactory { internal TestRandom Instance { get; set; }