From 311f8784c6f2300f85875927cbbdc7d29f0e2538 Mon Sep 17 00:00:00 2001 From: James Newton-King Date: Wed, 30 Aug 2023 11:56:19 +0800 Subject: [PATCH] Change subchannel BalancerAddress when attributes change (#2228) --- .../Balancer/BalancerAddress.cs | 24 ++- .../Balancer/BalancerAttributes.cs | 93 ++++++-- .../BalancerAddressEqualityComparer.cs | 6 +- .../Balancer/Internal/BalancerHttpHandler.cs | 2 +- .../Balancer/Internal/ISubchannelTransport.cs | 5 +- .../Internal/PassiveSubchannelTransport.cs | 14 +- .../SocketConnectivitySubchannelTransport.cs | 200 +++++++++--------- src/Grpc.Net.Client/Balancer/Subchannel.cs | 41 +++- .../Balancer/SubchannelsLoadBalancer.cs | 8 + .../Balancer/ConnectionTests.cs | 6 +- .../Balancer/PickFirstBalancerTests.cs | 6 +- .../Balancer/RoundRobinBalancerTests.cs | 16 +- .../Balancer/RoundRobinBalancerTests.cs | 28 ++- .../Balancer/TestSubChannelTransport.cs | 8 +- 14 files changed, 299 insertions(+), 158 deletions(-) diff --git a/src/Grpc.Net.Client/Balancer/BalancerAddress.cs b/src/Grpc.Net.Client/Balancer/BalancerAddress.cs index 0c6322cb1..dd9bfe27a 100644 --- a/src/Grpc.Net.Client/Balancer/BalancerAddress.cs +++ b/src/Grpc.Net.Client/Balancer/BalancerAddress.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // @@ -30,7 +30,9 @@ namespace Grpc.Net.Client.Balancer; /// public sealed class BalancerAddress { - private BalancerAttributes? _attributes; + // Internal so address attributes can be compared without using the Attributes property. + // The property allocates an empty collection if one isn't already present. + internal BalancerAttributes? _attributes; /// /// Initializes a new instance of the class with the specified . @@ -48,7 +50,7 @@ public BalancerAddress(DnsEndPoint endPoint) /// The host. /// The port. [DebuggerStepThrough] - public BalancerAddress(string host, int port) : this(new DnsEndPoint(host, port)) + public BalancerAddress(string host, int port) : this(new BalancerEndPoint(host, port)) { } @@ -69,5 +71,21 @@ public override string ToString() { return $"{EndPoint.Host}:{EndPoint.Port}"; } + + private sealed class BalancerEndPoint : DnsEndPoint + { + private string? _cachedToString; + + public BalancerEndPoint(string host, int port) : base(host, port) + { + } + + public override string ToString() + { + // Improve ToString performance when logging by caching ToString. + // Don't include DnsEndPoint address family. + return _cachedToString ??= $"{Host}:{Port}"; + } + } } #endif diff --git a/src/Grpc.Net.Client/Balancer/BalancerAttributes.cs b/src/Grpc.Net.Client/Balancer/BalancerAttributes.cs index 4d7069b19..c75095c08 100644 --- a/src/Grpc.Net.Client/Balancer/BalancerAttributes.cs +++ b/src/Grpc.Net.Client/Balancer/BalancerAttributes.cs @@ -38,20 +38,22 @@ public sealed class BalancerAttributes : IDictionary, IReadOnly /// /// Gets a read-only collection of metadata attributes. /// - public static readonly BalancerAttributes Empty = new BalancerAttributes(new ReadOnlyDictionary(new Dictionary())); + public static readonly BalancerAttributes Empty = new BalancerAttributes(new Dictionary(), readOnly: true); - private readonly IDictionary _attributes; + private readonly Dictionary _attributes; + private readonly bool _readOnly; /// /// Initializes a new instance of the class. /// - public BalancerAttributes() : this(new Dictionary()) + public BalancerAttributes() : this(new Dictionary(), readOnly: false) { } - private BalancerAttributes(IDictionary attributes) + private BalancerAttributes(Dictionary attributes, bool readOnly) { _attributes = attributes; + _readOnly = readOnly; } object? IDictionary.this[string key] @@ -62,6 +64,7 @@ private BalancerAttributes(IDictionary attributes) } set { + ValidateReadOnly(); _attributes[key] = value; } } @@ -69,21 +72,41 @@ private BalancerAttributes(IDictionary attributes) ICollection IDictionary.Keys => _attributes.Keys; ICollection IDictionary.Values => _attributes.Values; int ICollection>.Count => _attributes.Count; - bool ICollection>.IsReadOnly => _attributes.IsReadOnly; + bool ICollection>.IsReadOnly => _readOnly || ((ICollection>)_attributes).IsReadOnly; IEnumerable IReadOnlyDictionary.Keys => _attributes.Keys; IEnumerable IReadOnlyDictionary.Values => _attributes.Values; int IReadOnlyCollection>.Count => _attributes.Count; object? IReadOnlyDictionary.this[string key] => _attributes[key]; - void IDictionary.Add(string key, object? value) => _attributes.Add(key, value); - void ICollection>.Add(KeyValuePair item) => _attributes.Add(item); - void ICollection>.Clear() => _attributes.Clear(); + void IDictionary.Add(string key, object? value) + { + ValidateReadOnly(); + _attributes.Add(key, value); + } + void ICollection>.Add(KeyValuePair item) + { + ValidateReadOnly(); + ((ICollection>)_attributes).Add(item); + } + void ICollection>.Clear() + { + ValidateReadOnly(); + _attributes.Clear(); + } bool ICollection>.Contains(KeyValuePair item) => _attributes.Contains(item); bool IDictionary.ContainsKey(string key) => _attributes.ContainsKey(key); - void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) => _attributes.CopyTo(array, arrayIndex); + void ICollection>.CopyTo(KeyValuePair[] array, int arrayIndex) => ((ICollection>)_attributes).CopyTo(array, arrayIndex); IEnumerator> IEnumerable>.GetEnumerator() => _attributes.GetEnumerator(); IEnumerator System.Collections.IEnumerable.GetEnumerator() => ((System.Collections.IEnumerable)_attributes).GetEnumerator(); - bool IDictionary.Remove(string key) => _attributes.Remove(key); - bool ICollection>.Remove(KeyValuePair item) => _attributes.Remove(item); + bool IDictionary.Remove(string key) + { + ValidateReadOnly(); + return _attributes.Remove(key); + } + bool ICollection>.Remove(KeyValuePair item) + { + ValidateReadOnly(); + return ((ICollection>)_attributes).Remove(item); + } bool IDictionary.TryGetValue(string key, out object? value) => _attributes.TryGetValue(key, out value); bool IReadOnlyDictionary.ContainsKey(string key) => _attributes.ContainsKey(key); bool IReadOnlyDictionary.TryGetValue(string key, out object? value) => _attributes.TryGetValue(key, out value); @@ -121,6 +144,7 @@ public bool TryGetValue(BalancerAttributesKey key, [MaybeNullWhe /// The value. public void Set(BalancerAttributesKey key, TValue value) { + ValidateReadOnly(); _attributes[key.Key] = value; } @@ -135,10 +159,55 @@ public void Set(BalancerAttributesKey key, TValue value) /// public bool Remove(BalancerAttributesKey key) { + ValidateReadOnly(); return _attributes.Remove(key.Key); } - internal string DebuggerToString() + private void ValidateReadOnly() + { + if (_readOnly) + { + throw new NotSupportedException("Collection is read-only."); + } + } + + internal static bool DeepEquals(BalancerAttributes? x, BalancerAttributes? y) + { + var xValues = x?._attributes; + var yValues = y?._attributes; + + if (ReferenceEquals(xValues, yValues)) + { + return true; + } + + if (xValues == null || yValues == null) + { + return false; + } + + if (xValues.Count != yValues.Count) + { + return false; + } + + foreach (var kvp in xValues) + { + if (!yValues.TryGetValue(kvp.Key, out var value)) + { + return false; + } + + if (!Equals(kvp.Value, value)) + { + return false; + } + } + + return true; + } + + private string DebuggerToString() { return $"Count = {_attributes.Count}"; } diff --git a/src/Grpc.Net.Client/Balancer/Internal/BalancerAddressEqualityComparer.cs b/src/Grpc.Net.Client/Balancer/Internal/BalancerAddressEqualityComparer.cs index a1d0f2fcf..60c62bf59 100644 --- a/src/Grpc.Net.Client/Balancer/Internal/BalancerAddressEqualityComparer.cs +++ b/src/Grpc.Net.Client/Balancer/Internal/BalancerAddressEqualityComparer.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // @@ -17,8 +17,6 @@ #endregion #if SUPPORT_LOAD_BALANCING -using System; -using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; namespace Grpc.Net.Client.Balancer.Internal; @@ -44,7 +42,7 @@ public bool Equals(BalancerAddress? x, BalancerAddress? y) return false; } - return true; + return BalancerAttributes.DeepEquals(x._attributes, y._attributes); } public int GetHashCode([DisallowNull] BalancerAddress obj) diff --git a/src/Grpc.Net.Client/Balancer/Internal/BalancerHttpHandler.cs b/src/Grpc.Net.Client/Balancer/Internal/BalancerHttpHandler.cs index 9b7659ef0..33bb1d8b5 100644 --- a/src/Grpc.Net.Client/Balancer/Internal/BalancerHttpHandler.cs +++ b/src/Grpc.Net.Client/Balancer/Internal/BalancerHttpHandler.cs @@ -94,7 +94,7 @@ internal async ValueTask OnConnect(SocketsHttpConnectionContext context, } Debug.Assert(context.DnsEndPoint.Equals(currentAddress.EndPoint), "Context endpoint should equal address endpoint."); - return await subchannel.Transport.GetStreamAsync(currentAddress, cancellationToken).ConfigureAwait(false); + return await subchannel.Transport.GetStreamAsync(currentAddress.EndPoint, cancellationToken).ConfigureAwait(false); } #endif diff --git a/src/Grpc.Net.Client/Balancer/Internal/ISubchannelTransport.cs b/src/Grpc.Net.Client/Balancer/Internal/ISubchannelTransport.cs index 82ed0f9f3..5ef60ec9b 100644 --- a/src/Grpc.Net.Client/Balancer/Internal/ISubchannelTransport.cs +++ b/src/Grpc.Net.Client/Balancer/Internal/ISubchannelTransport.cs @@ -17,6 +17,7 @@ #endregion #if SUPPORT_LOAD_BALANCING +using System.Net; using Grpc.Shared; namespace Grpc.Net.Client.Balancer.Internal; @@ -28,11 +29,11 @@ namespace Grpc.Net.Client.Balancer.Internal; /// internal interface ISubchannelTransport : IDisposable { - BalancerAddress? CurrentAddress { get; } + DnsEndPoint? CurrentEndPoint { get; } TimeSpan? ConnectTimeout { get; } TransportStatus TransportStatus { get; } - ValueTask GetStreamAsync(BalancerAddress address, CancellationToken cancellationToken); + ValueTask GetStreamAsync(DnsEndPoint endPoint, CancellationToken cancellationToken); ValueTask TryConnectAsync(ConnectContext context); void Disconnect(); diff --git a/src/Grpc.Net.Client/Balancer/Internal/PassiveSubchannelTransport.cs b/src/Grpc.Net.Client/Balancer/Internal/PassiveSubchannelTransport.cs index 06c0ea5ba..979c8dc79 100644 --- a/src/Grpc.Net.Client/Balancer/Internal/PassiveSubchannelTransport.cs +++ b/src/Grpc.Net.Client/Balancer/Internal/PassiveSubchannelTransport.cs @@ -35,32 +35,32 @@ namespace Grpc.Net.Client.Balancer.Internal; internal class PassiveSubchannelTransport : ISubchannelTransport, IDisposable { private readonly Subchannel _subchannel; - private BalancerAddress? _currentAddress; + private DnsEndPoint? _currentEndPoint; public PassiveSubchannelTransport(Subchannel subchannel) { _subchannel = subchannel; } - public BalancerAddress? CurrentAddress => _currentAddress; + public DnsEndPoint? CurrentEndPoint => _currentEndPoint; public TimeSpan? ConnectTimeout { get; } public TransportStatus TransportStatus => TransportStatus.Passive; public void Disconnect() { - _currentAddress = null; + _currentEndPoint = null; _subchannel.UpdateConnectivityState(ConnectivityState.Idle, "Disconnected."); } public ValueTask TryConnectAsync(ConnectContext context) { Debug.Assert(_subchannel._addresses.Count == 1); - Debug.Assert(CurrentAddress == null); + Debug.Assert(CurrentEndPoint == null); var currentAddress = _subchannel._addresses[0]; _subchannel.UpdateConnectivityState(ConnectivityState.Connecting, "Passively connecting."); - _currentAddress = currentAddress; + _currentEndPoint = currentAddress.EndPoint; _subchannel.UpdateConnectivityState(ConnectivityState.Ready, "Passively connected."); return new ValueTask(ConnectResult.Success); @@ -68,10 +68,10 @@ public ValueTask TryConnectAsync(ConnectContext context) public void Dispose() { - _currentAddress = null; + _currentEndPoint = null; } - public ValueTask GetStreamAsync(BalancerAddress address, CancellationToken cancellationToken) + public ValueTask GetStreamAsync(DnsEndPoint endPoint, CancellationToken cancellationToken) { throw new NotSupportedException(); } diff --git a/src/Grpc.Net.Client/Balancer/Internal/SocketConnectivitySubchannelTransport.cs b/src/Grpc.Net.Client/Balancer/Internal/SocketConnectivitySubchannelTransport.cs index 6048874dd..b602843f2 100644 --- a/src/Grpc.Net.Client/Balancer/Internal/SocketConnectivitySubchannelTransport.cs +++ b/src/Grpc.Net.Client/Balancer/Internal/SocketConnectivitySubchannelTransport.cs @@ -43,7 +43,7 @@ internal class SocketConnectivitySubchannelTransport : ISubchannelTransport, IDi { private const int MaximumInitialSocketDataSize = 1024 * 16; internal static readonly TimeSpan SocketPingInterval = TimeSpan.FromSeconds(5); - internal readonly record struct ActiveStream(BalancerAddress Address, Socket Socket, Stream? Stream); + internal readonly record struct ActiveStream(DnsEndPoint EndPoint, Socket Socket, Stream? Stream); private readonly ILogger _logger; private readonly Subchannel _subchannel; @@ -55,11 +55,11 @@ internal class SocketConnectivitySubchannelTransport : ISubchannelTransport, IDi private int _lastEndPointIndex; internal Socket? _initialSocket; - private BalancerAddress? _initialSocketAddress; + private DnsEndPoint? _initialSocketEndPoint; private List>? _initialSocketData; private DateTime? _initialSocketCreatedTime; private bool _disposed; - private BalancerAddress? _currentAddress; + private DnsEndPoint? _currentEndPoint; public SocketConnectivitySubchannelTransport( Subchannel subchannel, @@ -80,7 +80,7 @@ public SocketConnectivitySubchannelTransport( } private object Lock => _subchannel.Lock; - public BalancerAddress? CurrentAddress => _currentAddress; + public DnsEndPoint? CurrentEndPoint => _currentEndPoint; public TimeSpan? ConnectTimeout { get; } public TransportStatus TransportStatus { @@ -137,16 +137,16 @@ private void DisconnectUnsynchronized() _initialSocket?.Dispose(); _initialSocket = null; - _initialSocketAddress = null; + _initialSocketEndPoint = null; _initialSocketData = null; _initialSocketCreatedTime = null; _lastEndPointIndex = 0; - _currentAddress = null; + _currentEndPoint = null; } public async ValueTask TryConnectAsync(ConnectContext context) { - Debug.Assert(CurrentAddress == null); + Debug.Assert(CurrentEndPoint == null); // Addresses could change while connecting. Make a copy of the subchannel's addresses. var addresses = _subchannel.GetAddresses(); @@ -157,7 +157,7 @@ public async ValueTask TryConnectAsync(ConnectContext context) for (var i = 0; i < addresses.Count; i++) { var currentIndex = (i + _lastEndPointIndex) % addresses.Count; - var currentAddress = addresses[currentIndex]; + var currentEndPoint = addresses[currentIndex].EndPoint; Socket socket; @@ -166,16 +166,16 @@ public async ValueTask TryConnectAsync(ConnectContext context) try { - SocketConnectivitySubchannelTransportLog.ConnectingSocket(_logger, _subchannel.Id, currentAddress); - await _socketConnect(socket, currentAddress.EndPoint, context.CancellationToken).ConfigureAwait(false); - SocketConnectivitySubchannelTransportLog.ConnectedSocket(_logger, _subchannel.Id, currentAddress); + SocketConnectivitySubchannelTransportLog.ConnectingSocket(_logger, _subchannel.Id, currentEndPoint); + await _socketConnect(socket, currentEndPoint, context.CancellationToken).ConfigureAwait(false); + SocketConnectivitySubchannelTransportLog.ConnectedSocket(_logger, _subchannel.Id, currentEndPoint); lock (Lock) { - _currentAddress = currentAddress; + _currentEndPoint = currentEndPoint; _lastEndPointIndex = currentIndex; _initialSocket = socket; - _initialSocketAddress = currentAddress; + _initialSocketEndPoint = currentEndPoint; _initialSocketData = null; _initialSocketCreatedTime = DateTime.UtcNow; @@ -190,7 +190,7 @@ public async ValueTask TryConnectAsync(ConnectContext context) } catch (Exception ex) { - SocketConnectivitySubchannelTransportLog.ErrorConnectingSocket(_logger, _subchannel.Id, currentAddress, ex); + SocketConnectivitySubchannelTransportLog.ErrorConnectingSocket(_logger, _subchannel.Id, currentEndPoint, ex); if (firstConnectionError == null) { @@ -250,7 +250,7 @@ private void OnCheckSocketConnection(object? state) try { Socket? socket; - BalancerAddress? socketAddress; + DnsEndPoint? socketEndpoint; var closeSocket = false; Exception? checkException = null; DateTime? socketCreatedTime; @@ -258,15 +258,15 @@ private void OnCheckSocketConnection(object? state) lock (Lock) { socket = _initialSocket; - socketAddress = _initialSocketAddress; + socketEndpoint = _initialSocketEndPoint; socketCreatedTime = _initialSocketCreatedTime; if (socket != null) { - CompatibilityHelpers.Assert(socketAddress != null); + CompatibilityHelpers.Assert(socketEndpoint != null); CompatibilityHelpers.Assert(socketCreatedTime != null); - closeSocket = ShouldCloseSocket(socket, socketAddress, ref _initialSocketData, out checkException); + closeSocket = ShouldCloseSocket(socket, socketEndpoint, ref _initialSocketData, out checkException); } } @@ -281,10 +281,10 @@ private void OnCheckSocketConnection(object? state) if (_initialSocket == socket) { - CompatibilityHelpers.Assert(socketAddress != null); + CompatibilityHelpers.Assert(socketEndpoint != null); CompatibilityHelpers.Assert(socketCreatedTime != null); - SocketConnectivitySubchannelTransportLog.ClosingUnusableSocket(_logger, _subchannel.Id, socketAddress, DateTime.UtcNow - socketCreatedTime.Value); + SocketConnectivitySubchannelTransportLog.ClosingUnusableSocket(_logger, _subchannel.Id, socketEndpoint, DateTime.UtcNow - socketCreatedTime.Value); DisconnectUnsynchronized(); } } @@ -308,32 +308,32 @@ private void OnCheckSocketConnection(object? state) } } - public async ValueTask GetStreamAsync(BalancerAddress address, CancellationToken cancellationToken) + public async ValueTask GetStreamAsync(DnsEndPoint endPoint, CancellationToken cancellationToken) { - SocketConnectivitySubchannelTransportLog.CreatingStream(_logger, _subchannel.Id, address); + SocketConnectivitySubchannelTransportLog.CreatingStream(_logger, _subchannel.Id, endPoint); Socket? socket = null; - BalancerAddress? socketAddress = null; + DnsEndPoint? socketEndPoint = null; List>? socketData = null; DateTime? socketCreatedTime = null; lock (Lock) { if (_initialSocket != null) { - var socketAddressMatch = Equals(_initialSocketAddress, address); + var socketEndPointMatch = Equals(_initialSocketEndPoint, endPoint); socket = _initialSocket; - socketAddress = _initialSocketAddress; + socketEndPoint = _initialSocketEndPoint; socketData = _initialSocketData; socketCreatedTime = _initialSocketCreatedTime; _initialSocket = null; - _initialSocketAddress = null; + _initialSocketEndPoint = null; _initialSocketData = null; _initialSocketCreatedTime = null; - // Double check the address matches the socket address and only use socket on match. + // Double check the endpoint matches the socket endpoint and only use socket on match. // Not sure if this is possible in practice, but better safe than sorry. - if (!socketAddressMatch) + if (!socketEndPointMatch) { socket.Dispose(); socket = null; @@ -351,12 +351,12 @@ public async ValueTask GetStreamAsync(BalancerAddress address, Cancellat if (_socketIdleTimeout != Timeout.InfiniteTimeSpan && DateTime.UtcNow > socketCreatedTime.Value.Add(_socketIdleTimeout)) { - SocketConnectivitySubchannelTransportLog.ClosingSocketFromIdleTimeoutOnCreateStream(_logger, _subchannel.Id, address, _socketIdleTimeout); + SocketConnectivitySubchannelTransportLog.ClosingSocketFromIdleTimeoutOnCreateStream(_logger, _subchannel.Id, endPoint, _socketIdleTimeout); closeSocket = true; } - else if (ShouldCloseSocket(socket, address, ref socketData, out _)) + else if (ShouldCloseSocket(socket, endPoint, ref socketData, out _)) { - SocketConnectivitySubchannelTransportLog.ClosingUnusableSocket(_logger, _subchannel.Id, address, DateTime.UtcNow - socketCreatedTime.Value); + SocketConnectivitySubchannelTransportLog.ClosingUnusableSocket(_logger, _subchannel.Id, endPoint, DateTime.UtcNow - socketCreatedTime.Value); closeSocket = true; } @@ -370,10 +370,10 @@ public async ValueTask GetStreamAsync(BalancerAddress address, Cancellat if (socket == null) { - SocketConnectivitySubchannelTransportLog.ConnectingOnCreateStream(_logger, _subchannel.Id, address); + SocketConnectivitySubchannelTransportLog.ConnectingOnCreateStream(_logger, _subchannel.Id, endPoint); socket = new Socket(SocketType.Stream, ProtocolType.Tcp) { NoDelay = true }; - await socket.ConnectAsync(address.EndPoint, cancellationToken).ConfigureAwait(false); + await socket.ConnectAsync(endPoint, cancellationToken).ConfigureAwait(false); } var networkStream = new NetworkStream(socket, ownsSocket: true); @@ -383,8 +383,8 @@ public async ValueTask GetStreamAsync(BalancerAddress address, Cancellat lock (Lock) { - _activeStreams.Add(new ActiveStream(address, socket, stream)); - SocketConnectivitySubchannelTransportLog.StreamCreated(_logger, _subchannel.Id, address, CalculateInitialSocketDataLength(socketData), _activeStreams.Count); + _activeStreams.Add(new ActiveStream(endPoint, socket, stream)); + SocketConnectivitySubchannelTransportLog.StreamCreated(_logger, _subchannel.Id, endPoint, CalculateInitialSocketDataLength(socketData), _activeStreams.Count); } return stream; @@ -394,13 +394,13 @@ public async ValueTask GetStreamAsync(BalancerAddress address, Cancellat /// Checks whether the socket is healthy. May read available data into the passed in buffer. /// Returns true if the socket should be closed. /// - private bool ShouldCloseSocket(Socket socket, BalancerAddress socketAddress, ref List>? socketData, out Exception? checkException) + private bool ShouldCloseSocket(Socket socket, DnsEndPoint socketEndPoint, ref List>? socketData, out Exception? checkException) { checkException = null; try { - SocketConnectivitySubchannelTransportLog.CheckingSocket(_logger, _subchannel.Id, socketAddress); + SocketConnectivitySubchannelTransportLog.CheckingSocket(_logger, _subchannel.Id, socketEndPoint); // Poll socket to check if it can be read from. Unfortunately this requires reading pending data. // The server might send data, e.g. HTTP/2 SETTINGS frame, so we need to read and cache it. @@ -411,7 +411,7 @@ private bool ShouldCloseSocket(Socket socket, BalancerAddress socketAddress, ref // We need to cache whatever we read so it isn't dropped. do { - if (PollSocket(socket, socketAddress)) + if (PollSocket(socket, socketEndPoint)) { // Polling socket reported an unhealthy state. return true; @@ -428,7 +428,7 @@ private bool ShouldCloseSocket(Socket socket, BalancerAddress socketAddress, ref throw new InvalidOperationException($"The server sent {serverDataAvailable} bytes to the client before a connection was established. Maximum allowed data exceeded."); } - SocketConnectivitySubchannelTransportLog.SocketReceivingAvailable(_logger, _subchannel.Id, socketAddress, available); + SocketConnectivitySubchannelTransportLog.SocketReceivingAvailable(_logger, _subchannel.Id, socketEndPoint, available); // Data is already available so this won't block. var buffer = new byte[available]; @@ -448,7 +448,7 @@ private bool ShouldCloseSocket(Socket socket, BalancerAddress socketAddress, ref catch (Exception ex) { checkException = ex; - SocketConnectivitySubchannelTransportLog.ErrorCheckingSocket(_logger, _subchannel.Id, socketAddress, ex); + SocketConnectivitySubchannelTransportLog.ErrorCheckingSocket(_logger, _subchannel.Id, socketEndPoint, ex); return true; } } @@ -458,7 +458,7 @@ private bool ShouldCloseSocket(Socket socket, BalancerAddress socketAddress, ref /// Shouldn't be used by itself as data needs to be consumed to accurately report the socket health. /// handles consuming data and getting the socket health. /// - private bool PollSocket(Socket socket, BalancerAddress address) + private bool PollSocket(Socket socket, DnsEndPoint endPoint) { // From https://github.com/dotnet/runtime/blob/3195fbbd82fdb7f132d6698591ba6489ad6dd8cf/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs#L158-L168 try @@ -471,14 +471,14 @@ private bool PollSocket(Socket socket, BalancerAddress address) } if (result) { - SocketConnectivitySubchannelTransportLog.SocketPollBadState(_logger, _subchannel.Id, address); + SocketConnectivitySubchannelTransportLog.SocketPollBadState(_logger, _subchannel.Id, endPoint); } return result; } catch (Exception ex) when (ex is SocketException || ex is ObjectDisposedException) { // Poll can throw when used on a closed socket. - SocketConnectivitySubchannelTransportLog.ErrorPollingSocket(_logger, _subchannel.Id, address, ex); + SocketConnectivitySubchannelTransportLog.ErrorPollingSocket(_logger, _subchannel.Id, endPoint, ex); return true; } } @@ -496,7 +496,7 @@ private void OnStreamDisposed(Stream streamWrapper) if (t.Stream == streamWrapper) { _activeStreams.RemoveAt(i); - SocketConnectivitySubchannelTransportLog.DisposingStream(_logger, _subchannel.Id, t.Address, _activeStreams.Count); + SocketConnectivitySubchannelTransportLog.DisposingStream(_logger, _subchannel.Id, t.EndPoint, _activeStreams.Count); // If the last active streams is removed then there is no active connection. disconnect = _activeStreams.Count == 0; @@ -543,29 +543,29 @@ public void Dispose() internal static class SocketConnectivitySubchannelTransportLog { - private static readonly Action _connectingSocket = - LoggerMessage.Define(LogLevel.Trace, new EventId(1, "ConnectingSocket"), "Subchannel id '{SubchannelId}' connecting socket to {Address}."); + private static readonly Action _connectingSocket = + LoggerMessage.Define(LogLevel.Trace, new EventId(1, "ConnectingSocket"), "Subchannel id '{SubchannelId}' connecting socket to {EndPoint}."); - private static readonly Action _connectedSocket = - LoggerMessage.Define(LogLevel.Debug, new EventId(2, "ConnectedSocket"), "Subchannel id '{SubchannelId}' connected to socket {Address}."); + private static readonly Action _connectedSocket = + LoggerMessage.Define(LogLevel.Debug, new EventId(2, "ConnectedSocket"), "Subchannel id '{SubchannelId}' connected to socket {EndPoint}."); - private static readonly Action _errorConnectingSocket = - LoggerMessage.Define(LogLevel.Debug, new EventId(3, "ErrorConnectingSocket"), "Subchannel id '{SubchannelId}' error connecting to socket {Address}."); + private static readonly Action _errorConnectingSocket = + LoggerMessage.Define(LogLevel.Debug, new EventId(3, "ErrorConnectingSocket"), "Subchannel id '{SubchannelId}' error connecting to socket {EndPoint}."); - private static readonly Action _checkingSocket = - LoggerMessage.Define(LogLevel.Trace, new EventId(4, "CheckingSocket"), "Subchannel id '{SubchannelId}' checking socket {Address}."); + private static readonly Action _checkingSocket = + LoggerMessage.Define(LogLevel.Trace, new EventId(4, "CheckingSocket"), "Subchannel id '{SubchannelId}' checking socket {EndPoint}."); - private static readonly Action _errorCheckingSocket = - LoggerMessage.Define(LogLevel.Debug, new EventId(5, "ErrorCheckingSocket"), "Subchannel id '{SubchannelId}' error checking socket {Address}."); + private static readonly Action _errorCheckingSocket = + LoggerMessage.Define(LogLevel.Debug, new EventId(5, "ErrorCheckingSocket"), "Subchannel id '{SubchannelId}' error checking socket {EndPoint}."); private static readonly Action _errorSocketTimer = LoggerMessage.Define(LogLevel.Error, new EventId(6, "ErrorSocketTimer"), "Subchannel id '{SubchannelId}' unexpected error in check socket timer."); - private static readonly Action _creatingStream = - LoggerMessage.Define(LogLevel.Trace, new EventId(7, "CreatingStream"), "Subchannel id '{SubchannelId}' creating stream for {Address}."); + private static readonly Action _creatingStream = + LoggerMessage.Define(LogLevel.Trace, new EventId(7, "CreatingStream"), "Subchannel id '{SubchannelId}' creating stream for {EndPoint}."); - private static readonly Action _disposingStream = - LoggerMessage.Define(LogLevel.Trace, new EventId(8, "DisposingStream"), "Subchannel id '{SubchannelId}' disposing stream for {Address}. Transport has {ActiveStreams} active streams."); + private static readonly Action _disposingStream = + LoggerMessage.Define(LogLevel.Trace, new EventId(8, "DisposingStream"), "Subchannel id '{SubchannelId}' disposing stream for {EndPoint}. Transport has {ActiveStreams} active streams."); private static readonly Action _disposingTransport = LoggerMessage.Define(LogLevel.Trace, new EventId(9, "DisposingTransport"), "Subchannel id '{SubchannelId}' disposing transport."); @@ -573,50 +573,50 @@ internal static class SocketConnectivitySubchannelTransportLog private static readonly Action _errorOnDisposingStream = LoggerMessage.Define(LogLevel.Error, new EventId(10, "ErrorOnDisposingStream"), "Subchannel id '{SubchannelId}' unexpected error when reacting to transport stream dispose."); - private static readonly Action _connectingOnCreateStream = - LoggerMessage.Define(LogLevel.Trace, new EventId(11, "ConnectingOnCreateStream"), "Subchannel id '{SubchannelId}' doesn't have a connected socket available. Connecting new stream socket for {Address}."); + private static readonly Action _connectingOnCreateStream = + LoggerMessage.Define(LogLevel.Trace, new EventId(11, "ConnectingOnCreateStream"), "Subchannel id '{SubchannelId}' doesn't have a connected socket available. Connecting new stream socket for {EndPoint}."); - private static readonly Action _streamCreated = - LoggerMessage.Define(LogLevel.Trace, new EventId(12, "StreamCreated"), "Subchannel id '{SubchannelId}' created stream for {Address} with {BufferedBytes} buffered bytes. Transport has {ActiveStreams} active streams."); + private static readonly Action _streamCreated = + LoggerMessage.Define(LogLevel.Trace, new EventId(12, "StreamCreated"), "Subchannel id '{SubchannelId}' created stream for {EndPoint} with {BufferedBytes} buffered bytes. Transport has {ActiveStreams} active streams."); - private static readonly Action _errorPollingSocket = - LoggerMessage.Define(LogLevel.Debug, new EventId(13, "ErrorPollingSocket"), "Subchannel id '{SubchannelId}' error checking socket {Address}."); + private static readonly Action _errorPollingSocket = + LoggerMessage.Define(LogLevel.Debug, new EventId(13, "ErrorPollingSocket"), "Subchannel id '{SubchannelId}' error checking socket {EndPoint}."); - private static readonly Action _socketPollBadState = - LoggerMessage.Define(LogLevel.Debug, new EventId(14, "SocketPollBadState"), "Subchannel id '{SubchannelId}' socket {Address} is in a bad state and can't be used."); + private static readonly Action _socketPollBadState = + LoggerMessage.Define(LogLevel.Debug, new EventId(14, "SocketPollBadState"), "Subchannel id '{SubchannelId}' socket {EndPoint} is in a bad state and can't be used."); - private static readonly Action _socketReceivingAvailable = - LoggerMessage.Define(LogLevel.Trace, new EventId(15, "SocketReceivingAvailable"), "Subchannel id '{SubchannelId}' socket {Address} is receiving {ReadBytesAvailableCount} available bytes."); + private static readonly Action _socketReceivingAvailable = + LoggerMessage.Define(LogLevel.Trace, new EventId(15, "SocketReceivingAvailable"), "Subchannel id '{SubchannelId}' socket {EndPoint} is receiving {ReadBytesAvailableCount} available bytes."); - private static readonly Action _closingUnusableSocket = - LoggerMessage.Define(LogLevel.Debug, new EventId(16, "ClosingUnusableSocket"), "Subchannel id '{SubchannelId}' socket {Address} is being closed because it can't be used. Socket lifetime of {SocketLifetime}. The socket either can't receive data or it has received unexpected data."); + private static readonly Action _closingUnusableSocket = + LoggerMessage.Define(LogLevel.Debug, new EventId(16, "ClosingUnusableSocket"), "Subchannel id '{SubchannelId}' socket {EndPoint} is being closed because it can't be used. Socket lifetime of {SocketLifetime}. The socket either can't receive data or it has received unexpected data."); - private static readonly Action _closingSocketFromIdleTimeoutOnCreateStream = - LoggerMessage.Define(LogLevel.Debug, new EventId(16, "ClosingSocketFromIdleTimeoutOnCreateStream"), "Subchannel id '{SubchannelId}' socket {Address} is being closed because it exceeds the idle timeout of {SocketIdleTimeout}."); + private static readonly Action _closingSocketFromIdleTimeoutOnCreateStream = + LoggerMessage.Define(LogLevel.Debug, new EventId(16, "ClosingSocketFromIdleTimeoutOnCreateStream"), "Subchannel id '{SubchannelId}' socket {EndPoint} is being closed because it exceeds the idle timeout of {SocketIdleTimeout}."); - public static void ConnectingSocket(ILogger logger, string subchannelId, BalancerAddress address) + public static void ConnectingSocket(ILogger logger, string subchannelId, DnsEndPoint endPoint) { - _connectingSocket(logger, subchannelId, address, null); + _connectingSocket(logger, subchannelId, endPoint, null); } - public static void ConnectedSocket(ILogger logger, string subchannelId, BalancerAddress address) + public static void ConnectedSocket(ILogger logger, string subchannelId, DnsEndPoint endPoint) { - _connectedSocket(logger, subchannelId, address, null); + _connectedSocket(logger, subchannelId, endPoint, null); } - public static void ErrorConnectingSocket(ILogger logger, string subchannelId, BalancerAddress address, Exception ex) + public static void ErrorConnectingSocket(ILogger logger, string subchannelId, DnsEndPoint endPoint, Exception ex) { - _errorConnectingSocket(logger, subchannelId, address, ex); + _errorConnectingSocket(logger, subchannelId, endPoint, ex); } - public static void CheckingSocket(ILogger logger, string subchannelId, BalancerAddress address) + public static void CheckingSocket(ILogger logger, string subchannelId, DnsEndPoint endPoint) { - _checkingSocket(logger, subchannelId, address, null); + _checkingSocket(logger, subchannelId, endPoint, null); } - public static void ErrorCheckingSocket(ILogger logger, string subchannelId, BalancerAddress address, Exception ex) + public static void ErrorCheckingSocket(ILogger logger, string subchannelId, DnsEndPoint endPoint, Exception ex) { - _errorCheckingSocket(logger, subchannelId, address, ex); + _errorCheckingSocket(logger, subchannelId, endPoint, ex); } public static void ErrorSocketTimer(ILogger logger, string subchannelId, Exception ex) @@ -624,14 +624,14 @@ public static void ErrorSocketTimer(ILogger logger, string subchannelId, Excepti _errorSocketTimer(logger, subchannelId, ex); } - public static void CreatingStream(ILogger logger, string subchannelId, BalancerAddress address) + public static void CreatingStream(ILogger logger, string subchannelId, DnsEndPoint endPoint) { - _creatingStream(logger, subchannelId, address, null); + _creatingStream(logger, subchannelId, endPoint, null); } - public static void DisposingStream(ILogger logger, string subchannelId, BalancerAddress address, int activeStreams) + public static void DisposingStream(ILogger logger, string subchannelId, DnsEndPoint endPoint, int activeStreams) { - _disposingStream(logger, subchannelId, address, activeStreams, null); + _disposingStream(logger, subchannelId, endPoint, activeStreams, null); } public static void DisposingTransport(ILogger logger, string subchannelId) @@ -644,39 +644,39 @@ public static void ErrorOnDisposingStream(ILogger logger, string subchannelId, E _errorOnDisposingStream(logger, subchannelId, ex); } - public static void ConnectingOnCreateStream(ILogger logger, string subchannelId, BalancerAddress address) + public static void ConnectingOnCreateStream(ILogger logger, string subchannelId, DnsEndPoint endPoint) { - _connectingOnCreateStream(logger, subchannelId, address, null); + _connectingOnCreateStream(logger, subchannelId, endPoint, null); } - public static void StreamCreated(ILogger logger, string subchannelId, BalancerAddress address, int bufferedBytes, int activeStreams) + public static void StreamCreated(ILogger logger, string subchannelId, DnsEndPoint endPoint, int bufferedBytes, int activeStreams) { - _streamCreated(logger, subchannelId, address, bufferedBytes, activeStreams, null); + _streamCreated(logger, subchannelId, endPoint, bufferedBytes, activeStreams, null); } - public static void ErrorPollingSocket(ILogger logger, string subchannelId, BalancerAddress address, Exception ex) + public static void ErrorPollingSocket(ILogger logger, string subchannelId, DnsEndPoint endPoint, Exception ex) { - _errorPollingSocket(logger, subchannelId, address, ex); + _errorPollingSocket(logger, subchannelId, endPoint, ex); } - public static void SocketPollBadState(ILogger logger, string subchannelId, BalancerAddress address) + public static void SocketPollBadState(ILogger logger, string subchannelId, DnsEndPoint endPoint) { - _socketPollBadState(logger, subchannelId, address, null); + _socketPollBadState(logger, subchannelId, endPoint, null); } - public static void SocketReceivingAvailable(ILogger logger, string subchannelId, BalancerAddress address, int readBytesAvailableCount) + public static void SocketReceivingAvailable(ILogger logger, string subchannelId, DnsEndPoint endPoint, int readBytesAvailableCount) { - _socketReceivingAvailable(logger, subchannelId, address, readBytesAvailableCount, null); + _socketReceivingAvailable(logger, subchannelId, endPoint, readBytesAvailableCount, null); } - public static void ClosingUnusableSocket(ILogger logger, string subchannelId, BalancerAddress address, TimeSpan socketLifetime) + public static void ClosingUnusableSocket(ILogger logger, string subchannelId, DnsEndPoint endPoint, TimeSpan socketLifetime) { - _closingUnusableSocket(logger, subchannelId, address, socketLifetime, null); + _closingUnusableSocket(logger, subchannelId, endPoint, socketLifetime, null); } - public static void ClosingSocketFromIdleTimeoutOnCreateStream(ILogger logger, string subchannelId, BalancerAddress address, TimeSpan socketIdleTimeout) + public static void ClosingSocketFromIdleTimeoutOnCreateStream(ILogger logger, string subchannelId, DnsEndPoint endPoint, TimeSpan socketIdleTimeout) { - _closingSocketFromIdleTimeoutOnCreateStream(logger, subchannelId, address, socketIdleTimeout, null); + _closingSocketFromIdleTimeoutOnCreateStream(logger, subchannelId, endPoint, socketIdleTimeout, null); } } #endif diff --git a/src/Grpc.Net.Client/Balancer/Subchannel.cs b/src/Grpc.Net.Client/Balancer/Subchannel.cs index af394e0ef..12161e14a 100644 --- a/src/Grpc.Net.Client/Balancer/Subchannel.cs +++ b/src/Grpc.Net.Client/Balancer/Subchannel.cs @@ -17,6 +17,7 @@ #endregion #if SUPPORT_LOAD_BALANCING +using System.Net; using Grpc.Core; using Grpc.Net.Client.Balancer.Internal; using Microsoft.Extensions.Logging; @@ -62,7 +63,21 @@ public sealed class Subchannel : IDisposable /// /// Gets the current connected address. /// - public BalancerAddress? CurrentAddress => _transport.CurrentAddress; + public BalancerAddress? CurrentAddress + { + get + { + if (_transport.CurrentEndPoint is { } ep) + { + lock (Lock) + { + return GetAddressByEndpoint(_addresses, ep); + } + } + + return null; + } + } /// /// Gets the metadata attributes. @@ -173,10 +188,13 @@ public void UpdateAddresses(IReadOnlyList addresses) case ConnectivityState.Ready: // Transport uses the subchannel lock but take copy in an abundance of caution. var currentAddress = CurrentAddress; - if (currentAddress != null && !_addresses.Contains(currentAddress)) + if (currentAddress != null) { - SubchannelLog.ConnectedAddressNotInUpdatedAddresses(_logger, Id, currentAddress); - requireReconnect = true; + if (GetAddressByEndpoint(_addresses, currentAddress.EndPoint) != null) + { + SubchannelLog.ConnectedAddressNotInUpdatedAddresses(_logger, Id, currentAddress); + requireReconnect = true; + } } break; case ConnectivityState.Shutdown: @@ -361,7 +379,7 @@ internal bool UpdateConnectivityState(ConnectivityState state, string successDet { return UpdateConnectivityState(state, new Status(StatusCode.OK, successDetail)); } - + internal bool UpdateConnectivityState(ConnectivityState state, Status status) { lock (Lock) @@ -402,6 +420,19 @@ internal void RaiseStateChanged(ConnectivityState state, Status status) } } + private static BalancerAddress? GetAddressByEndpoint(List addresses, DnsEndPoint endPoint) + { + foreach (var a in addresses) + { + if (a.EndPoint.Equals(endPoint)) + { + return a; + } + } + + return null; + } + /// public override string ToString() { diff --git a/src/Grpc.Net.Client/Balancer/SubchannelsLoadBalancer.cs b/src/Grpc.Net.Client/Balancer/SubchannelsLoadBalancer.cs index d90044f63..c152796df 100644 --- a/src/Grpc.Net.Client/Balancer/SubchannelsLoadBalancer.cs +++ b/src/Grpc.Net.Client/Balancer/SubchannelsLoadBalancer.cs @@ -140,6 +140,14 @@ public override void UpdateChannelState(ChannelState state) // remaining in this collection at the end will be disposed. currentSubchannels.RemoveAt(i.Value); + // Check if address attributes have changed. If they have then update the subchannel address. + // The new subchannel address has the same endpoint so the connection isn't impacted. + if (!BalancerAddressEqualityComparer.Instance.Equals(address, newOrCurrentSubchannel.Address)) + { + newOrCurrentSubchannel.Subchannel.UpdateAddresses(new[] { address }); + newOrCurrentSubchannel = new AddressSubchannel(newOrCurrentSubchannel.Subchannel, address); + } + SubchannelLog.SubchannelPreserved(_logger, newOrCurrentSubchannel.Subchannel.Id, address); } else diff --git a/test/FunctionalTests/Balancer/ConnectionTests.cs b/test/FunctionalTests/Balancer/ConnectionTests.cs index f2ec1828a..b89a1772c 100644 --- a/test/FunctionalTests/Balancer/ConnectionTests.cs +++ b/test/FunctionalTests/Balancer/ConnectionTests.cs @@ -346,7 +346,7 @@ await TestHelpers.AssertIsTrueRetryAsync(() => }, "Wait for connections to start."); foreach (var t in activeStreams) { - Assert.AreEqual(new DnsEndPoint("127.0.0.1", 50051), t.Address.EndPoint); + Assert.AreEqual(new DnsEndPoint("127.0.0.1", 50051), t.EndPoint); } // Act @@ -367,7 +367,7 @@ await TestHelpers.AssertIsTrueRetryAsync(() => activeStreams = transport.GetActiveStreams(); return activeStreams.Count == 11; }, "Wait for connections to start."); - Assert.AreEqual(new DnsEndPoint("127.0.0.1", 50051), activeStreams[activeStreams.Count - 1].Address.EndPoint); + Assert.AreEqual(new DnsEndPoint("127.0.0.1", 50051), activeStreams[activeStreams.Count - 1].EndPoint); tcs.SetResult(null); @@ -407,7 +407,7 @@ await TestHelpers.AssertIsTrueRetryAsync(() => activeStreams = transport.GetActiveStreams(); Assert.AreEqual(1, activeStreams.Count); - Assert.AreEqual(new DnsEndPoint("127.0.0.1", 50052), activeStreams[0].Address.EndPoint); + Assert.AreEqual(new DnsEndPoint("127.0.0.1", 50052), activeStreams[0].EndPoint); } #if NET7_0_OR_GREATER diff --git a/test/FunctionalTests/Balancer/PickFirstBalancerTests.cs b/test/FunctionalTests/Balancer/PickFirstBalancerTests.cs index 611f68349..d257214c0 100644 --- a/test/FunctionalTests/Balancer/PickFirstBalancerTests.cs +++ b/test/FunctionalTests/Balancer/PickFirstBalancerTests.cs @@ -371,7 +371,7 @@ async Task UnaryMethod(HelloRequest request, ServerCallContext conte Assert.GreaterOrEqual(activeStreams.Count, 2); foreach (var stream in activeStreams) { - Assert.AreEqual(new DnsEndPoint("127.0.0.1", 50051), stream.Address.EndPoint); + Assert.AreEqual(new DnsEndPoint("127.0.0.1", 50051), stream.EndPoint); } tcs.SetResult(null); @@ -385,7 +385,7 @@ async Task UnaryMethod(HelloRequest request, ServerCallContext conte await TestHelpers.AssertIsTrueRetryAsync(() => { activeStreams = transport.GetActiveStreams(); - Logger.LogInformation($"Current active stream addresses: {string.Join(", ", activeStreams.Select(s => s.Address))}"); + Logger.LogInformation($"Current active stream addresses: {string.Join(", ", activeStreams.Select(s => s.EndPoint))}"); return activeStreams.Count == 0; }, "Active streams removed.", Logger).DefaultTimeout(); @@ -395,7 +395,7 @@ await TestHelpers.AssertIsTrueRetryAsync(() => activeStreams = transport.GetActiveStreams(); Assert.AreEqual(1, activeStreams.Count); - Assert.AreEqual(new DnsEndPoint("127.0.0.1", 50052), activeStreams[0].Address.EndPoint); + Assert.AreEqual(new DnsEndPoint("127.0.0.1", 50052), activeStreams[0].EndPoint); } [Test] diff --git a/test/FunctionalTests/Balancer/RoundRobinBalancerTests.cs b/test/FunctionalTests/Balancer/RoundRobinBalancerTests.cs index d8c8af667..815e19409 100644 --- a/test/FunctionalTests/Balancer/RoundRobinBalancerTests.cs +++ b/test/FunctionalTests/Balancer/RoundRobinBalancerTests.cs @@ -381,8 +381,8 @@ Task UnaryMethod(HelloRequest request, ServerCallContext context) var activeStreams = ((SocketConnectivitySubchannelTransport)disposedSubchannel.Transport).GetActiveStreams(); Assert.AreEqual(1, activeStreams.Count); - Assert.AreEqual("127.0.0.1", activeStreams[0].Address.EndPoint.Host); - Assert.AreEqual(50051, activeStreams[0].Address.EndPoint.Port); + Assert.AreEqual("127.0.0.1", activeStreams[0].EndPoint.Host); + Assert.AreEqual(50051, activeStreams[0].EndPoint.Port); // Wait until connected to new endpoint Subchannel? newSubchannel = null; @@ -406,15 +406,17 @@ Task UnaryMethod(HelloRequest request, ServerCallContext context) Assert.AreEqual("127.0.0.1:50052", host!); // Disposed subchannel stream removed when endpoint disposed. - activeStreams = ((SocketConnectivitySubchannelTransport)disposedSubchannel.Transport).GetActiveStreams(); - Assert.AreEqual(0, activeStreams.Count); - Assert.IsNull(((SocketConnectivitySubchannelTransport)disposedSubchannel.Transport)._initialSocket); + await TestHelpers.AssertIsTrueRetryAsync(() => + { + var disposedTransport = (SocketConnectivitySubchannelTransport)disposedSubchannel.Transport; + return disposedTransport.GetActiveStreams().Count == 0 && disposedTransport._initialSocket == null; + }, "Wait for SocketsHttpHandler to react to server closing streams.").DefaultTimeout(); // New subchannel stream created with request. activeStreams = ((SocketConnectivitySubchannelTransport)newSubchannel.Transport).GetActiveStreams(); Assert.AreEqual(1, activeStreams.Count); - Assert.AreEqual("127.0.0.1", activeStreams[0].Address.EndPoint.Host); - Assert.AreEqual(50052, activeStreams[0].Address.EndPoint.Port); + Assert.AreEqual("127.0.0.1", activeStreams[0].EndPoint.Host); + Assert.AreEqual(50052, activeStreams[0].EndPoint.Port); Assert.IsNull(((SocketConnectivitySubchannelTransport)disposedSubchannel.Transport)._initialSocket); } diff --git a/test/Grpc.Net.Client.Tests/Balancer/RoundRobinBalancerTests.cs b/test/Grpc.Net.Client.Tests/Balancer/RoundRobinBalancerTests.cs index 745ede8c8..692a875a4 100644 --- a/test/Grpc.Net.Client.Tests/Balancer/RoundRobinBalancerTests.cs +++ b/test/Grpc.Net.Client.Tests/Balancer/RoundRobinBalancerTests.cs @@ -328,7 +328,8 @@ public async Task HasSubchannels_ResolverRefresh_MatchingSubchannelUnchanged() resolver.UpdateAddresses(new List { new BalancerAddress("localhost", 80), - new BalancerAddress("localhost", 81) + new BalancerAddress("localhost", 81), + new BalancerAddress("localhost", 82) }); // Act @@ -340,31 +341,44 @@ public async Task HasSubchannels_ResolverRefresh_MatchingSubchannelUnchanged() await connectTask.DefaultTimeout(); var subchannels = channel.ConnectionManager.GetSubchannels(); - Assert.AreEqual(2, subchannels.Count); + Assert.AreEqual(3, subchannels.Count); Assert.AreEqual(1, subchannels[0]._addresses.Count); Assert.AreEqual(new DnsEndPoint("localhost", 80), subchannels[0]._addresses[0].EndPoint); Assert.AreEqual(1, subchannels[1]._addresses.Count); Assert.AreEqual(new DnsEndPoint("localhost", 81), subchannels[1]._addresses[0].EndPoint); + Assert.AreEqual(1, subchannels[2]._addresses.Count); + Assert.AreEqual(new DnsEndPoint("localhost", 82), subchannels[2]._addresses[0].EndPoint); - // Preserved because port 81 is in both refresh results - var preservedSubchannel = subchannels[1]; + // Preserved because port 81, 82 is in both refresh results + var preservedSubchannel1 = subchannels[1]; + var preservedSubchannel2 = subchannels[2]; + + var address2 = new BalancerAddress("localhost", 82); + address2.Attributes.Set(new BalancerAttributesKey("test"), 1); resolver.UpdateAddresses(new List { new BalancerAddress("localhost", 81), - new BalancerAddress("localhost", 82) + address2, + new BalancerAddress("localhost", 83) }); subchannels = channel.ConnectionManager.GetSubchannels(); - Assert.AreEqual(2, subchannels.Count); + Assert.AreEqual(3, subchannels.Count); Assert.AreEqual(1, subchannels[0]._addresses.Count); Assert.AreEqual(new DnsEndPoint("localhost", 81), subchannels[0]._addresses[0].EndPoint); Assert.AreEqual(1, subchannels[1]._addresses.Count); Assert.AreEqual(new DnsEndPoint("localhost", 82), subchannels[1]._addresses[0].EndPoint); + Assert.AreEqual(1, subchannels[2]._addresses.Count); + Assert.AreEqual(new DnsEndPoint("localhost", 83), subchannels[2]._addresses[0].EndPoint); + + Assert.AreSame(preservedSubchannel1, subchannels[0]); + Assert.AreSame(preservedSubchannel2, subchannels[1]); - Assert.AreSame(preservedSubchannel, subchannels[0]); + // Test that the channel's address was updated with new attribute with new attributes. + Assert.AreSame(preservedSubchannel2.CurrentAddress, address2); } } #endif diff --git a/test/Grpc.Net.Client.Tests/Infrastructure/Balancer/TestSubChannelTransport.cs b/test/Grpc.Net.Client.Tests/Infrastructure/Balancer/TestSubChannelTransport.cs index 67543a5fb..171b9bfc2 100644 --- a/test/Grpc.Net.Client.Tests/Infrastructure/Balancer/TestSubChannelTransport.cs +++ b/test/Grpc.Net.Client.Tests/Infrastructure/Balancer/TestSubChannelTransport.cs @@ -38,7 +38,7 @@ internal class TestSubchannelTransport : ISubchannelTransport public Subchannel Subchannel { get; } - public BalancerAddress? CurrentAddress { get; private set; } + public DnsEndPoint? CurrentEndPoint { get; private set; } public TimeSpan? ConnectTimeout => _factory.ConnectTimeout; public TransportStatus TransportStatus => TransportStatus.Passive; @@ -62,14 +62,14 @@ public void Dispose() { } - public ValueTask GetStreamAsync(BalancerAddress address, CancellationToken cancellationToken) + public ValueTask GetStreamAsync(DnsEndPoint endPoint, CancellationToken cancellationToken) { return new ValueTask(new MemoryStream()); } public void Disconnect() { - CurrentAddress = null; + CurrentEndPoint = null; Subchannel.UpdateConnectivityState(ConnectivityState.Idle, "Disconnected."); } @@ -83,7 +83,7 @@ public async { var (newState, connectResult) = await (_onTryConnect?.Invoke(context.CancellationToken) ?? Task.FromResult(new TryConnectResult(ConnectivityState.Ready))); - CurrentAddress = Subchannel._addresses[0]; + CurrentEndPoint = Subchannel._addresses[0].EndPoint; var newStatus = newState == ConnectivityState.TransientFailure ? new Status(StatusCode.Internal, "") : Status.DefaultSuccess; Subchannel.UpdateConnectivityState(newState, newStatus);