From 17a56d97629d291325de8cdb53c9a400cda4ec1c Mon Sep 17 00:00:00 2001 From: Liangying Wei Date: Fri, 26 Sep 2025 12:06:15 +0800 Subject: [PATCH 1/2] Fix the shutdown timeout hung --- .../ServerConnections/ServiceConnection.cs | 4 +- .../Auth/LocalTokenProvider.cs | 5 +- .../MicrosoftEntraTokenProvider.cs | 5 +- .../Interfaces/IAccessTokenProvider.cs | 5 +- .../Internal/WebSocketsTransport.cs | 2 +- .../ServiceConnectionBase.cs | 89 ++++++++++--------- .../ServiceConnectionContainerBase.cs | 19 +++- .../HubHost/ServiceHubDispatcher.cs | 2 +- .../ServerConnections/ServiceConnection.cs | 4 +- .../TestClasses/TestServiceConnectionProxy.cs | 3 +- .../TestClasses/TestServiceConnection.cs | 4 +- 11 files changed, 84 insertions(+), 58 deletions(-) diff --git a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs index 7b0354a38..1db3545cb 100644 --- a/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs +++ b/src/Microsoft.Azure.SignalR.AspNet/ServerConnections/ServiceConnection.cs @@ -76,9 +76,9 @@ public override Task CloseClientConnections(CancellationToken token) throw new NotSupportedException(); } - protected override Task CreateConnection(string target = null) + protected override Task CreateConnection(string target = null, CancellationToken cancellationToken = default) { - return _connectionFactory.ConnectAsync(HubEndpoint, TransferFormat.Binary, ConnectionId, target); + return _connectionFactory.ConnectAsync(HubEndpoint, TransferFormat.Binary, ConnectionId, target, cancellationToken); } protected override Task DisposeConnection(ConnectionContext connection) diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/LocalTokenProvider.cs b/src/Microsoft.Azure.SignalR.Common/Auth/LocalTokenProvider.cs index cb31715f0..baf965b96 100644 --- a/src/Microsoft.Azure.SignalR.Common/Auth/LocalTokenProvider.cs +++ b/src/Microsoft.Azure.SignalR.Common/Auth/LocalTokenProvider.cs @@ -1,9 +1,10 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; using System.Collections.Generic; using System.Security.Claims; +using System.Threading; using System.Threading.Tasks; namespace Microsoft.Azure.SignalR; @@ -34,5 +35,5 @@ public LocalTokenProvider( _tokenLifetime = tokenLifetime ?? Constants.Periods.DefaultAccessTokenLifetime; } - public Task ProvideAsync() => _accessKey.GenerateAccessTokenAsync(_audience, _claims, _tokenLifetime, _algorithm); + public Task ProvideAsync(CancellationToken cancellationToken) => _accessKey.GenerateAccessTokenAsync(_audience, _claims, _tokenLifetime, _algorithm, cancellationToken); } diff --git a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraTokenProvider.cs b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraTokenProvider.cs index cf0077523..5a01997ff 100644 --- a/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraTokenProvider.cs +++ b/src/Microsoft.Azure.SignalR.Common/Auth/MicrosoftEntra/MicrosoftEntraTokenProvider.cs @@ -1,7 +1,8 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; +using System.Threading; using System.Threading.Tasks; namespace Microsoft.Azure.SignalR; @@ -15,5 +16,5 @@ public MicrosoftEntraTokenProvider(MicrosoftEntraAccessKey accessKey) _accessKey = accessKey ?? throw new ArgumentNullException(nameof(accessKey)); } - public Task ProvideAsync() => _accessKey.GetMicrosoftEntraTokenAsync(); + public Task ProvideAsync(CancellationToken cancellationToken) => _accessKey.GetMicrosoftEntraTokenAsync(cancellationToken); } diff --git a/src/Microsoft.Azure.SignalR.Common/Interfaces/IAccessTokenProvider.cs b/src/Microsoft.Azure.SignalR.Common/Interfaces/IAccessTokenProvider.cs index 367397e64..906f4955f 100644 --- a/src/Microsoft.Azure.SignalR.Common/Interfaces/IAccessTokenProvider.cs +++ b/src/Microsoft.Azure.SignalR.Common/Interfaces/IAccessTokenProvider.cs @@ -1,11 +1,12 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. +using System.Threading; using System.Threading.Tasks; namespace Microsoft.Azure.SignalR; internal interface IAccessTokenProvider { - Task ProvideAsync(); + Task ProvideAsync(CancellationToken cancellationToken = default); } diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/Internal/WebSocketsTransport.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/Internal/WebSocketsTransport.cs index 49a6a4d09..639d132d6 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/Internal/WebSocketsTransport.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/Internal/WebSocketsTransport.cs @@ -121,7 +121,7 @@ public async Task StartAsync(Uri url, CancellationToken cancellationToken = defa // We don't need to capture to a local because we never change this delegate. if (_accessTokenProvider != null) { - accessToken = await _accessTokenProvider.ProvideAsync(); + accessToken = await _accessTokenProvider.ProvideAsync(cancellationToken); if (!string.IsNullOrEmpty(accessToken)) { _webSocket.Options.SetRequestHeader("Authorization", $"Bearer {accessToken}"); diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs index 94cae2ce1..698bc6b1c 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionBase.cs @@ -44,6 +44,8 @@ internal abstract partial class ServiceConnectionBase : IServiceConnection private readonly TaskCompletionSource _serviceConnectionOfflineTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly CancellationTokenSource _connectionStartCts = new(); + private readonly ServiceConnectionType _connectionType; private readonly IServiceMessageHandler _serviceMessageHandler; @@ -157,63 +159,66 @@ public async Task StartAsync(string target = null) } Status = ServiceConnectionStatus.Connecting; - - var connection = await EstablishConnectionAsync(target); - if (connection != null) + try { - _connectionContext = connection; - Status = ServiceConnectionStatus.Connected; - _serviceConnectionStartTcs.TrySetResult(true); - try + var connection = await EstablishConnectionAsync(target, _connectionStartCts.Token); + if (connection != null) { - TimerAwaitable syncTimer = null; + _connectionContext = connection; + Status = ServiceConnectionStatus.Connected; + _serviceConnectionStartTcs.TrySetResult(true); try { - if (HubEndpoint != null && HubEndpoint.AccessKey is MicrosoftEntraAccessKey key) + TimerAwaitable syncTimer = null; + try { - syncTimer = new TimerAwaitable(TimeSpan.Zero, DefaultSyncAzureIdentityInterval); - _ = UpdateAzureIdentityAsync(key, syncTimer); + if (HubEndpoint != null && HubEndpoint.AccessKey is MicrosoftEntraAccessKey key) + { + syncTimer = new TimerAwaitable(TimeSpan.Zero, DefaultSyncAzureIdentityInterval); + _ = UpdateAzureIdentityAsync(key, syncTimer); + } + await ProcessIncomingAsync(connection); } - await ProcessIncomingAsync(connection); - } - finally - { - // mark the status as Disconnected so that no one will write to this connection anymore - Status = ServiceConnectionStatus.Disconnected; - syncTimer?.Stop(); + finally + { + // mark the status as Disconnected so that no one will write to this connection anymore + Status = ServiceConnectionStatus.Disconnected; + syncTimer?.Stop(); - // when ProcessIncoming completes, clean up the connection + // when ProcessIncoming completes, clean up the connection - // TODO: Never cleanup connections unless Service asks us to do that - // Current implementation is based on assumption that Service will drop clients - // if server connection fails. - await CleanupClientConnections(); + // TODO: Never cleanup connections unless Service asks us to do that + // Current implementation is based on assumption that Service will drop clients + // if server connection fails. + await CleanupClientConnections(); + } } - } - catch (Exception ex) - { - Log.ConnectionDropped(Logger, _endpointName, ConnectionId, ex); - } - finally - { - // wait until all the connections are cleaned up to close the outgoing pipe - // Don't allow write anymore when the connection is disconnected - await _writeLock.WaitAsync(); - try + catch (Exception ex) { - // close the underlying connection - await DisposeConnection(connection); + Log.ConnectionDropped(Logger, _endpointName, ConnectionId, ex); } finally { - _writeLock.Release(); + // wait until all the connections are cleaned up to close the outgoing pipe + // Don't allow write anymore when the connection is disconnected + await _writeLock.WaitAsync(); + try + { + // close the underlying connection + await DisposeConnection(connection); + } + finally + { + _writeLock.Release(); + } } } } - else + finally { Status = ServiceConnectionStatus.Disconnected; _serviceConnectionStartTcs.TrySetResult(false); + _serviceConnectionOfflineTcs.TrySetResult(false); } } @@ -221,6 +226,8 @@ public Task StopAsync() { try { + // to avoid the connection hung in connecting state + _connectionStartCts.Cancel(); _connectionContext?.Transport.Input.CancelPendingRead(); } catch (Exception ex) @@ -277,7 +284,7 @@ public virtual async Task SafeWriteAsync(ServiceMessage serviceMessage) public abstract bool TryRemoveClientConnection(string connectionId, out IClientConnection connection); - protected abstract Task CreateConnection(string target = null); + protected abstract Task CreateConnection(string target = null, CancellationToken cancellationToken = default); protected abstract Task DisposeConnection(ConnectionContext connection); @@ -492,11 +499,11 @@ private Task OnFlowControlMessageAsync(ConnectionFlowControlMessage flowControlM throw new NotImplementedException($"Unsupported connection type: {flowControlMessage.ConnectionType}"); } - private async Task EstablishConnectionAsync(string target) + private async Task EstablishConnectionAsync(string target, CancellationToken cancellationToken) { try { - var connectionContext = await CreateConnection(target); + var connectionContext = await CreateConnection(target, cancellationToken); try { if (await HandshakeAsync(connectionContext)) diff --git a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs index 5592ea956..8c68ae72f 100644 --- a/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs +++ b/src/Microsoft.Azure.SignalR.Common/ServiceConnections/ServiceConnectionContainerBase.cs @@ -318,7 +318,7 @@ public async IAsyncEnumerable> ListConnectionsInGro public virtual Task OfflineAsync(GracefulShutdownMode mode, CancellationToken token) { _terminated = true; - return Task.WhenAll(ServiceConnections.Select(c => RemoveConnectionAsync(c, mode, token))); + return Task.WhenAll(ServiceConnections.Select(c => RemoveConnectionFromServiceAsync(c, mode, token))); } public virtual Task CloseClientConnections(CancellationToken token) @@ -476,8 +476,23 @@ protected virtual ServiceConnectionStatus GetStatus() : ServiceConnectionStatus.Disconnected; } - protected async Task RemoveConnectionAsync(IServiceConnection c, GracefulShutdownMode mode, CancellationToken token) + /// + /// TODO: this logic sounds more fit into the serviceConnection class + /// + /// The service connection instance + /// The graceful shutdown mode + /// The cancellation token + /// + protected async Task RemoveConnectionFromServiceAsync(IServiceConnection c, GracefulShutdownMode mode, CancellationToken token) { + if (c.Status != ServiceConnectionStatus.Connected) + { + // if the connection is not yet connected + // we stop the connection in case it is connecting + // otherwise ConnectionOfflineTask should be set + await c.StopAsync(); + return; + } var retry = 0; while (retry < MaxRetryRemoveSeverConnection && !token.IsCancellationRequested) { diff --git a/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs b/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs index fde7312bc..6c6018f50 100644 --- a/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs +++ b/src/Microsoft.Azure.SignalR/HubHost/ServiceHubDispatcher.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; diff --git a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs index fd040f520..80b2c3f73 100644 --- a/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs +++ b/src/Microsoft.Azure.SignalR/ServerConnections/ServiceConnection.cs @@ -129,9 +129,9 @@ public override async Task CloseClientConnections(CancellationToken token) } } - protected override Task CreateConnection(string target = null) + protected override Task CreateConnection(string target = null, CancellationToken cancellationToken = default) { - return _connectionFactory.ConnectAsync(HubEndpoint, TransferFormat.Binary, ConnectionId, target); + return _connectionFactory.ConnectAsync(HubEndpoint, TransferFormat.Binary, ConnectionId, target, cancellationToken); } protected override Task DisposeConnection(ConnectionContext connection) diff --git a/test/Microsoft.Azure.SignalR.AspNet.Tests/TestClasses/TestServiceConnectionProxy.cs b/test/Microsoft.Azure.SignalR.AspNet.Tests/TestClasses/TestServiceConnectionProxy.cs index 9ce9b3e9c..d9904659b 100644 --- a/test/Microsoft.Azure.SignalR.AspNet.Tests/TestClasses/TestServiceConnectionProxy.cs +++ b/test/Microsoft.Azure.SignalR.AspNet.Tests/TestClasses/TestServiceConnectionProxy.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Concurrent; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; @@ -83,7 +84,7 @@ public override async Task SafeWriteAsync(ServiceMessage serviceMessage) return result; } - protected override async Task CreateConnection(string target = null) + protected override async Task CreateConnection(string target = null, CancellationToken cancellationToken = default) { TestConnectionContext = await base.CreateConnection() as TestConnectionContext; diff --git a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnection.cs b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnection.cs index 8c8801c5a..aa493eeed 100644 --- a/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnection.cs +++ b/test/Microsoft.Azure.SignalR.Tests.Common/TestClasses/TestServiceConnection.cs @@ -1,4 +1,4 @@ -// Copyright (c) Microsoft. All rights reserved. +// Copyright (c) Microsoft. All rights reserved. // Licensed under the MIT license. See LICENSE file in the project root for full license information. using System; @@ -93,7 +93,7 @@ protected override Task CleanupClientConnections(string? fromInstanceId = null) return Task.CompletedTask; } - protected override Task CreateConnection(string? target = null) + protected override Task CreateConnection(string? target = null, CancellationToken cancellationToken = default) { var pipeOptions = new PipeOptions(); var duplex = DuplexPipe.CreateConnectionPair(pipeOptions, pipeOptions); From 22df81766d997c74b0060393f5b1125563e419d6 Mon Sep 17 00:00:00 2001 From: Liangying Wei Date: Fri, 26 Sep 2025 13:29:45 +0800 Subject: [PATCH 2/2] Fix tests --- .../ServiceConnectionContainerBaseTests.cs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerBaseTests.cs b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerBaseTests.cs index 63e26c79c..539646786 100644 --- a/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerBaseTests.cs +++ b/test/Microsoft.Azure.SignalR.Tests/ServiceConnectionContainerBaseTests.cs @@ -257,6 +257,7 @@ public Task StartAsync(string target = null) public Task StopAsync() { + _offline.TrySetResult(true); return Task.CompletedTask; } @@ -264,11 +265,11 @@ public Task WriteAsync(ServiceMessage serviceMessage) { if (RuntimeServicePingMessage.IsFin(serviceMessage)) { - _offline.SetResult(true); + _offline.TrySetResult(true); } if (RuntimeServicePingMessage.IsGetServers(serviceMessage)) { - _serversPing.SetResult(true); + _serversPing.TrySetResult(true); } return Task.CompletedTask; }