diff --git a/src/ReverseProxy.Core/Abstractions/BackendDiscovery/Contract/LoadBalancingMode.cs b/src/ReverseProxy.Core/Abstractions/BackendDiscovery/Contract/LoadBalancingMode.cs
index f6d11a411a..c6c890e625 100644
--- a/src/ReverseProxy.Core/Abstractions/BackendDiscovery/Contract/LoadBalancingMode.cs
+++ b/src/ReverseProxy.Core/Abstractions/BackendDiscovery/Contract/LoadBalancingMode.cs
@@ -22,6 +22,10 @@ public enum LoadBalancingMode
///
Random,
///
+ /// Selects an endpoint by cycling through them in order.
+ ///
+ RoundRobin,
+ ///
/// Select the first endpoint without considering load. This is useful for dual endpoint fail-over systems.
///
First,
diff --git a/src/ReverseProxy.Core/Middleware/LoadBalancingMiddleware.cs b/src/ReverseProxy.Core/Middleware/LoadBalancingMiddleware.cs
index 48b7edb75c..43fd427190 100644
--- a/src/ReverseProxy.Core/Middleware/LoadBalancingMiddleware.cs
+++ b/src/ReverseProxy.Core/Middleware/LoadBalancingMiddleware.cs
@@ -41,7 +41,8 @@ public Task Invoke(HttpContext context)
var endpoints = endpointsFeature?.Endpoints
?? throw new InvalidOperationException("The AvailableBackendEndpoints collection was not set.");
- var loadBalancingOptions = backend.Config.Value?.LoadBalancingOptions ?? default;
+ var loadBalancingOptions = backend.Config.Value?.LoadBalancingOptions
+ ?? new BackendConfig.BackendLoadBalancingOptions(default);
var endpoint = _operationLogger.Execute(
"ReverseProxy.PickEndpoint",
diff --git a/src/ReverseProxy.Core/Service/Proxy/LoadBalancer.cs b/src/ReverseProxy.Core/Service/Proxy/LoadBalancer.cs
index cdc4b6a29d..6f831d46ef 100644
--- a/src/ReverseProxy.Core/Service/Proxy/LoadBalancer.cs
+++ b/src/ReverseProxy.Core/Service/Proxy/LoadBalancer.cs
@@ -40,6 +40,9 @@ public EndpointInfo PickEndpoint(
{
case LoadBalancingMode.First:
return endpoints[0];
+ case LoadBalancingMode.RoundRobin:
+ var offset = loadBalancingOptions.RoundRobinCounter.Increment();
+ return endpoints[offset % endpoints.Count];
case LoadBalancingMode.Random:
var random = _randomFactory.CreateRandomInstance();
return endpoints[random.Next(endpointCount)];
diff --git a/src/ReverseProxy.Core/Service/RuntimeModel/BackendConfig.cs b/src/ReverseProxy.Core/Service/RuntimeModel/BackendConfig.cs
index d3f1706bbc..aa31ee0a2a 100644
--- a/src/ReverseProxy.Core/Service/RuntimeModel/BackendConfig.cs
+++ b/src/ReverseProxy.Core/Service/RuntimeModel/BackendConfig.cs
@@ -3,6 +3,7 @@
using System;
using Microsoft.ReverseProxy.Core.Abstractions;
+using Microsoft.ReverseProxy.Core.Util;
namespace Microsoft.ReverseProxy.Core.RuntimeModel
{
@@ -79,9 +80,13 @@ internal readonly struct BackendLoadBalancingOptions
public BackendLoadBalancingOptions(LoadBalancingMode mode)
{
Mode = mode;
+ // Increment returns the new value and we want the first return value to be 0.
+ RoundRobinCounter = new AtomicCounter() { Value = -1 };
}
public LoadBalancingMode Mode { get; }
+
+ public AtomicCounter RoundRobinCounter { get; }
}
}
}
diff --git a/src/ReverseProxy.Core/Util/AtomicCounter.cs b/src/ReverseProxy.Core/Util/AtomicCounter.cs
index d1a715f8cd..7094d71ecf 100644
--- a/src/ReverseProxy.Core/Util/AtomicCounter.cs
+++ b/src/ReverseProxy.Core/Util/AtomicCounter.cs
@@ -12,22 +12,25 @@ public class AtomicCounter
///
/// Gets the current value of the counter.
///
- public int Value => Volatile.Read(ref _value);
+ public int Value {
+ get => Volatile.Read(ref _value);
+ set => Volatile.Write(ref _value, value);
+ }
///
/// Atomically increments the counter value by 1.
///
- public void Increment()
+ public int Increment()
{
- Interlocked.Increment(ref _value);
+ return Interlocked.Increment(ref _value);
}
///
/// Atomically decrements the counter value by 1.
///
- public void Decrement()
+ public int Decrement()
{
- Interlocked.Decrement(ref _value);
+ return Interlocked.Decrement(ref _value);
}
///
diff --git a/test/ReverseProxy.Core.Tests/Middleware/LoadBalancerMiddlewareTests.cs b/test/ReverseProxy.Core.Tests/Middleware/LoadBalancerMiddlewareTests.cs
index c381da6eb9..01a32ab386 100644
--- a/test/ReverseProxy.Core.Tests/Middleware/LoadBalancerMiddlewareTests.cs
+++ b/test/ReverseProxy.Core.Tests/Middleware/LoadBalancerMiddlewareTests.cs
@@ -8,6 +8,7 @@
using Microsoft.AspNetCore.Routing.Patterns;
using Microsoft.ReverseProxy.Common.Abstractions.Telemetry;
using Microsoft.ReverseProxy.Common.Telemetry;
+using Microsoft.ReverseProxy.Core.Abstractions;
using Microsoft.ReverseProxy.Core.RuntimeModel;
using Microsoft.ReverseProxy.Core.Service.Management;
using Microsoft.ReverseProxy.Core.Service.Proxy;
@@ -40,6 +41,7 @@ public async Task Invoke_Works()
backendId: "backend1",
endpointManager: new EndpointManager(),
proxyHttpClientFactory: proxyHttpClientFactoryMock.Object);
+ backend1.Config.Value = new BackendConfig(default, new BackendConfig.BackendLoadBalancingOptions(LoadBalancingMode.RoundRobin));
var endpoint1 = backend1.EndpointManager.GetOrCreateItem(
"endpoint1",
endpoint =>
@@ -66,10 +68,7 @@ public async Task Invoke_Works()
aspNetCoreEndpoints.Add(aspNetCoreEndpoint);
var httpContext = new DefaultHttpContext();
httpContext.SetEndpoint(aspNetCoreEndpoint);
-
- Mock()
- .Setup(l => l.PickEndpoint(It.IsAny>(), It.IsAny()))
- .Returns(endpoint1);
+ Provide();
httpContext.Features.Set(
new AvailableBackendEndpointsFeature() { Endpoints = new List() { endpoint1, endpoint2 }.AsReadOnly() });
diff --git a/test/ReverseProxy.Core.Tests/Service/Proxy/LoadBalancerTests.cs b/test/ReverseProxy.Core.Tests/Service/Proxy/LoadBalancerTests.cs
index 4345a68f4f..2b7c067b50 100644
--- a/test/ReverseProxy.Core.Tests/Service/Proxy/LoadBalancerTests.cs
+++ b/test/ReverseProxy.Core.Tests/Service/Proxy/LoadBalancerTests.cs
@@ -135,6 +135,29 @@ public void PickEndpoint_LeastRequests_Works()
Assert.Same(result, endpoints[1]);
}
+ [Fact]
+ public void PickEndpoint_RoundRobin_Works()
+ {
+ var loadBalancer = Create();
+ var endpoints = new[]
+ {
+ new EndpointInfo("ep1"),
+ new EndpointInfo("ep2"),
+ };
+ endpoints[0].ConcurrencyCounter.Increment();
+ var options = new BackendConfig.BackendLoadBalancingOptions(LoadBalancingMode.RoundRobin);
+
+ var result0 = loadBalancer.PickEndpoint(endpoints, in options);
+ var result1 = loadBalancer.PickEndpoint(endpoints, in options);
+ var result2 = loadBalancer.PickEndpoint(endpoints, in options);
+ var result3 = loadBalancer.PickEndpoint(endpoints, in options);
+
+ Assert.Same(result0, endpoints[0]);
+ Assert.Same(result1, endpoints[1]);
+ Assert.Same(result2, endpoints[0]);
+ Assert.Same(result3, endpoints[1]);
+ }
+
internal class TestRandomFactory : IRandomFactory
{
internal TestRandom Instance { get; set; }